Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@ package main

import (
"encoding/json"
"fmt"
"log"
"net/http"
"sort"
"strings"
"time"

"charm.land/catwalk/internal/providers"
"charm.land/catwalk/pkg/catwalk"
"github.com/charmbracelet/x/etag"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
Expand Down Expand Up @@ -69,6 +73,109 @@ func providersHandler(w http.ResponseWriter, r *http.Request) {
}
}

func providersSpecificHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}

path := strings.TrimPrefix(r.URL.Path, "/v2/providers/")
providerID := strings.TrimSuffix(path, ".md")
providerID = strings.TrimSuffix(providerID, "/")

if providerID == "" {
http.Error(w, "Provider ID is required", http.StatusBadRequest)
return
}

isMarkdown := strings.HasSuffix(r.URL.Path, ".md")

allProviders := providers.GetAll()

var foundProvider *catwalk.Provider
for _, provider := range allProviders {
if string(provider.ID) == providerID {
foundProvider = &provider
break
}
}

if foundProvider == nil {
http.Error(w, "Provider not found", http.StatusNotFound)
return
}

if isMarkdown {
// Return as markdown table
w.Header().Set("Content-Type", "text/markdown")
renderProviderMarkdown(w, *foundProvider, r)
} else {
// Return as JSON
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(foundProvider); err != nil {
log.Printf("Error encoding provider response: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
}

func yesNo(b bool) string {
if b {
return "yes"
}
return "no"
}

func renderProviderMarkdown(w http.ResponseWriter, provider catwalk.Provider, r *http.Request) {
var buf strings.Builder

fmt.Fprintf(&buf, "# Provider: %s\n\n", provider.Name)
fmt.Fprintf(&buf, "| Field | Value |\n")
fmt.Fprintf(&buf, "|-------|-------|\n")
fmt.Fprintf(&buf, "| ID | `%s` |\n", provider.ID)
fmt.Fprintf(&buf, "| Type | `%s` |\n", provider.Type)
fmt.Fprintf(&buf, "| API Endpoint | `%s` |\n", provider.APIEndpoint)
fmt.Fprintf(&buf, "| Default Large Model ID | `%s` |\n", provider.DefaultLargeModelID)
fmt.Fprintf(&buf, "| Default Small Model ID | `%s` |\n", provider.DefaultSmallModelID)

if len(provider.Models) > 0 {
fmt.Fprintf(&buf, "\n## Available Models\n\n")

models := provider.Models
sortParam := r.URL.Query().Get("sort")

switch sortParam {
case "output":
// Sort by output cost (ascending)
sort.Slice(models, func(i, j int) bool {
return models[i].CostPer1MOut < models[j].CostPer1MOut
})
case "context":
// Sort by context window (descending)
sort.Slice(models, func(i, j int) bool {
return models[i].ContextWindow > models[j].ContextWindow
})
default:
// Default sort by input cost (ascending)
sort.Slice(models, func(i, j int) bool {
return models[i].CostPer1MIn < models[j].CostPer1MIn
})
}

fmt.Fprintf(&buf, "| Name | ID | Context Window | Input Cost ($/M) | Output Cost ($/M) | Reasoning |\n")
fmt.Fprintf(&buf, "|----------|------|----------------|------------------|------------------|----------|\n")
for _, model := range models {
fmt.Fprintf(&buf, "| %s | `%s` | %d | %.6f | %.6f | %s |\n", model.Name, model.ID, model.ContextWindow, model.CostPer1MIn, model.CostPer1MOut, yesNo(model.CanReason))
}
}

if _, err := w.Write([]byte(buf.String())); err != nil {
log.Printf("Error writing markdown response: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}

func providersHandlerDeprecated(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")

Expand All @@ -81,6 +188,7 @@ func providersHandlerDeprecated(w http.ResponseWriter, _ *http.Request) {
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/v2/providers", providersHandler)
mux.HandleFunc("/v2/providers/", providersSpecificHandler)
mux.HandleFunc("/providers", providersHandlerDeprecated)
mux.HandleFunc("/healthz", func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
Expand Down
Loading