diff --git a/main.go b/main.go index 5bdb8496..66a0c0cb 100644 --- a/main.go +++ b/main.go @@ -4,11 +4,15 @@ package main import ( "encoding/json" + "fmt" "log" "net/http" + "sort" + "strings" "time" "charm.land/catwalk/internal/providers" + "charm.land/catwalk/pkg/catwalk" "github.com/charmbracelet/x/etag" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -69,6 +73,109 @@ func providersHandler(w http.ResponseWriter, r *http.Request) { } } +func providersSpecificHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + path := strings.TrimPrefix(r.URL.Path, "/v2/providers/") + providerID := strings.TrimSuffix(path, ".md") + providerID = strings.TrimSuffix(providerID, "/") + + if providerID == "" { + http.Error(w, "Provider ID is required", http.StatusBadRequest) + return + } + + isMarkdown := strings.HasSuffix(r.URL.Path, ".md") + + allProviders := providers.GetAll() + + var foundProvider *catwalk.Provider + for _, provider := range allProviders { + if string(provider.ID) == providerID { + foundProvider = &provider + break + } + } + + if foundProvider == nil { + http.Error(w, "Provider not found", http.StatusNotFound) + return + } + + if isMarkdown { + // Return as markdown table + w.Header().Set("Content-Type", "text/markdown") + renderProviderMarkdown(w, *foundProvider, r) + } else { + // Return as JSON + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(foundProvider); err != nil { + log.Printf("Error encoding provider response: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } +} + +func yesNo(b bool) string { + if b { + return "yes" + } + return "no" +} + +func renderProviderMarkdown(w http.ResponseWriter, provider catwalk.Provider, r *http.Request) { + var buf strings.Builder + + fmt.Fprintf(&buf, "# Provider: %s\n\n", provider.Name) + fmt.Fprintf(&buf, "| Field | Value |\n") + fmt.Fprintf(&buf, "|-------|-------|\n") + fmt.Fprintf(&buf, "| ID | `%s` |\n", provider.ID) + fmt.Fprintf(&buf, "| Type | `%s` |\n", provider.Type) + fmt.Fprintf(&buf, "| API Endpoint | `%s` |\n", provider.APIEndpoint) + fmt.Fprintf(&buf, "| Default Large Model ID | `%s` |\n", provider.DefaultLargeModelID) + fmt.Fprintf(&buf, "| Default Small Model ID | `%s` |\n", provider.DefaultSmallModelID) + + if len(provider.Models) > 0 { + fmt.Fprintf(&buf, "\n## Available Models\n\n") + + models := provider.Models + sortParam := r.URL.Query().Get("sort") + + switch sortParam { + case "output": + // Sort by output cost (ascending) + sort.Slice(models, func(i, j int) bool { + return models[i].CostPer1MOut < models[j].CostPer1MOut + }) + case "context": + // Sort by context window (descending) + sort.Slice(models, func(i, j int) bool { + return models[i].ContextWindow > models[j].ContextWindow + }) + default: + // Default sort by input cost (ascending) + sort.Slice(models, func(i, j int) bool { + return models[i].CostPer1MIn < models[j].CostPer1MIn + }) + } + + fmt.Fprintf(&buf, "| Name | ID | Context Window | Input Cost ($/M) | Output Cost ($/M) | Reasoning |\n") + fmt.Fprintf(&buf, "|----------|------|----------------|------------------|------------------|----------|\n") + for _, model := range models { + fmt.Fprintf(&buf, "| %s | `%s` | %d | %.6f | %.6f | %s |\n", model.Name, model.ID, model.ContextWindow, model.CostPer1MIn, model.CostPer1MOut, yesNo(model.CanReason)) + } + } + + if _, err := w.Write([]byte(buf.String())); err != nil { + log.Printf("Error writing markdown response: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + func providersHandlerDeprecated(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -81,6 +188,7 @@ func providersHandlerDeprecated(w http.ResponseWriter, _ *http.Request) { func main() { mux := http.NewServeMux() mux.HandleFunc("/v2/providers", providersHandler) + mux.HandleFunc("/v2/providers/", providersSpecificHandler) mux.HandleFunc("/providers", providersHandlerDeprecated) mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK)