-
Notifications
You must be signed in to change notification settings - Fork 303
Expand file tree
/
Copy pathmain.go
More file actions
164 lines (136 loc) · 4.46 KB
/
main.go
File metadata and controls
164 lines (136 loc) · 4.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
package main
import (
"context"
"cursor2api-go/config"
"cursor2api-go/handlers"
"cursor2api-go/middleware"
"cursor2api-go/models"
"fmt"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
func main() {
// 加载配置
cfg, err := config.LoadConfig()
if err != nil {
logrus.Fatalf("Failed to load config: %v", err)
}
// 设置日志级别和 GIN 模式(必须在创建路由器之前设置)
if cfg.Debug {
logrus.SetLevel(logrus.DebugLevel)
gin.SetMode(gin.DebugMode)
} else {
logrus.SetLevel(logrus.InfoLevel)
gin.SetMode(gin.ReleaseMode)
}
// 禁用 Gin 的调试信息输出
gin.DisableConsoleColor()
// 创建路由器(使用 gin.New() 而不是 gin.Default() 以避免默认日志)
router := gin.New()
// 添加中间件
router.Use(gin.Recovery())
router.Use(middleware.CORS())
router.Use(middleware.ErrorHandler())
// 只在 Debug 模式下启用 GIN 的日志
if cfg.Debug {
router.Use(gin.Logger())
}
// 创建处理器
handler := handlers.NewHandler(cfg)
// 注册路由
setupRoutes(router, handler)
// 创建HTTP服务器
server := &http.Server{
Addr: fmt.Sprintf(":%d", cfg.Port),
Handler: router,
}
// 打印启动信息
printStartupBanner(cfg)
// 启动服务器的goroutine
go func() {
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logrus.Fatalf("Failed to start server: %v", err)
}
}()
// 等待中断信号以优雅关闭服务器
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
logrus.Info("Shutting down server...")
// 给服务器5秒时间完成处理正在进行的请求
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
logrus.Fatalf("Server forced to shutdown: %v", err)
}
logrus.Info("Server exited")
}
func setupRoutes(router *gin.Engine, handler *handlers.Handler) {
// 健康检查
router.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"status": "ok",
"time": time.Now().Unix(),
})
})
// API文档页面
router.GET("/", handler.ServeDocs)
// API v1路由组
v1 := router.Group("/v1")
{
// 模型列表
v1.GET("/models", handler.ListModels) // 模型列表不需要鉴权
// 聊天完成
v1.POST("/chat/completions", middleware.AuthRequired(), handler.ChatCompletions)
}
// 静态文件服务(如果需要)
router.Static("/static", "./static")
}
// printStartupBanner 打印启动横幅
func printStartupBanner(cfg *config.Config) {
banner := `
╔══════════════════════════════════════════════════════════════╗
║ Cursor2API Server ║
╚══════════════════════════════════════════════════════════════╝
`
fmt.Println(banner)
fmt.Printf("🚀 服务地址: http://localhost:%d\n", cfg.Port)
fmt.Printf("📚 API 文档: http://localhost:%d/\n", cfg.Port)
fmt.Printf("💊 健康检查: http://localhost:%d/health\n", cfg.Port)
fmt.Printf("🔑 API 密钥: %s\n", maskAPIKey(cfg.APIKey))
modelList := cfg.GetModels()
fmt.Printf("\n🤖 支持模型 (%d 个):\n", len(modelList))
// 按类别分组显示模型
providers := make(map[string][]string)
for _, modelID := range modelList {
if config, exists := models.GetModelConfig(modelID); exists {
providers[config.Provider] = append(providers[config.Provider], modelID)
} else {
providers["Other"] = append(providers["Other"], modelID)
}
}
// 按Provider排序并显示
for _, provider := range []string{"Anthropic", "OpenAI", "Google", "Other"} {
if models, ok := providers[provider]; ok && len(models) > 0 {
fmt.Printf(" %s: %s\n", provider, strings.Join(models, ", "))
}
}
if cfg.Debug {
fmt.Println("\n🐛 调试模式: 已启用")
}
fmt.Println("\n✨ 服务已启动,按 Ctrl+C 停止")
fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
}
// maskAPIKey 掩码 API 密钥,只显示前 4 位
func maskAPIKey(key string) string {
if len(key) <= 4 {
return "****"
}
return key[:4] + "****"
}