Skip to content

Commit 053699f

Browse files
authored
Merge commit from fork
fix: harden token search with pagination, rate limiting and input validation
2 parents f3d6e99 + 3e1be18 commit 053699f

11 files changed

Lines changed: 282 additions & 20 deletions

File tree

common/constants.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,10 @@ var (
175175

176176
DownloadRateLimitNum = 10
177177
DownloadRateLimitDuration int64 = 60
178+
179+
// Per-user search rate limit (applies after authentication, keyed by user ID)
180+
SearchRateLimitNum = 10
181+
SearchRateLimitDuration int64 = 60
178182
)
179183

180184
var RateLimitKeyExpirationDuration = 20 * time.Minute

common/utils.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ func Interface2String(inter interface{}) string {
192192
case int:
193193
return fmt.Sprintf("%d", inter.(int))
194194
case float64:
195-
return fmt.Sprintf("%f", inter.(float64))
195+
return strconv.FormatFloat(inter.(float64), 'f', -1, 64)
196196
case bool:
197197
if inter.(bool) {
198198
return "true"

controller/token.go

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/QuantumNous/new-api/common"
99
"github.com/QuantumNous/new-api/i18n"
1010
"github.com/QuantumNous/new-api/model"
11+
"github.com/QuantumNous/new-api/setting/operation_setting"
1112

1213
"github.com/gin-gonic/gin"
1314
)
@@ -31,16 +32,17 @@ func SearchTokens(c *gin.Context) {
3132
userId := c.GetInt("id")
3233
keyword := c.Query("keyword")
3334
token := c.Query("token")
34-
tokens, err := model.SearchUserTokens(userId, keyword, token)
35+
36+
pageInfo := common.GetPageQuery(c)
37+
38+
tokens, total, err := model.SearchUserTokens(userId, keyword, token, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
3539
if err != nil {
3640
common.ApiError(c, err)
3741
return
3842
}
39-
c.JSON(http.StatusOK, gin.H{
40-
"success": true,
41-
"message": "",
42-
"data": tokens,
43-
})
43+
pageInfo.SetTotal(int(total))
44+
pageInfo.SetItems(tokens)
45+
common.ApiSuccess(c, pageInfo)
4446
return
4547
}
4648

@@ -157,6 +159,20 @@ func AddToken(c *gin.Context) {
157159
return
158160
}
159161
}
162+
// 检查用户令牌数量是否已达上限
163+
maxTokens := operation_setting.GetMaxUserTokens()
164+
count, err := model.CountUserTokens(c.GetInt("id"))
165+
if err != nil {
166+
common.ApiError(c, err)
167+
return
168+
}
169+
if int(count) >= maxTokens {
170+
c.JSON(http.StatusOK, gin.H{
171+
"success": false,
172+
"message": fmt.Sprintf("已达到最大令牌数量限制 (%d)", maxTokens),
173+
})
174+
return
175+
}
160176
key, err := common.GenerateKey()
161177
if err != nil {
162178
common.ApiErrorI18n(c, i18n.MsgTokenGenerateFailed)

middleware/rate-limit.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,88 @@ func DownloadRateLimit() func(c *gin.Context) {
115115
func UploadRateLimit() func(c *gin.Context) {
116116
return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP")
117117
}
118+
119+
// userRateLimitFactory creates a rate limiter keyed by authenticated user ID
120+
// instead of client IP, making it resistant to proxy rotation attacks.
121+
// Must be used AFTER authentication middleware (UserAuth).
122+
func userRateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) {
123+
if common.RedisEnabled {
124+
return func(c *gin.Context) {
125+
userId := c.GetInt("id")
126+
if userId == 0 {
127+
c.Status(http.StatusUnauthorized)
128+
c.Abort()
129+
return
130+
}
131+
key := fmt.Sprintf("rateLimit:%s:user:%d", mark, userId)
132+
userRedisRateLimiter(c, maxRequestNum, duration, key)
133+
}
134+
}
135+
// It's safe to call multi times.
136+
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
137+
return func(c *gin.Context) {
138+
userId := c.GetInt("id")
139+
if userId == 0 {
140+
c.Status(http.StatusUnauthorized)
141+
c.Abort()
142+
return
143+
}
144+
key := fmt.Sprintf("%s:user:%d", mark, userId)
145+
if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) {
146+
c.Status(http.StatusTooManyRequests)
147+
c.Abort()
148+
return
149+
}
150+
}
151+
}
152+
153+
// userRedisRateLimiter is like redisRateLimiter but accepts a pre-built key
154+
// (to support user-ID-based keys).
155+
func userRedisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, key string) {
156+
ctx := context.Background()
157+
rdb := common.RDB
158+
listLength, err := rdb.LLen(ctx, key).Result()
159+
if err != nil {
160+
fmt.Println(err.Error())
161+
c.Status(http.StatusInternalServerError)
162+
c.Abort()
163+
return
164+
}
165+
if listLength < int64(maxRequestNum) {
166+
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
167+
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
168+
} else {
169+
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
170+
oldTime, err := time.Parse(timeFormat, oldTimeStr)
171+
if err != nil {
172+
fmt.Println(err)
173+
c.Status(http.StatusInternalServerError)
174+
c.Abort()
175+
return
176+
}
177+
nowTimeStr := time.Now().Format(timeFormat)
178+
nowTime, err := time.Parse(timeFormat, nowTimeStr)
179+
if err != nil {
180+
fmt.Println(err)
181+
c.Status(http.StatusInternalServerError)
182+
c.Abort()
183+
return
184+
}
185+
if int64(nowTime.Sub(oldTime).Seconds()) < duration {
186+
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
187+
c.Status(http.StatusTooManyRequests)
188+
c.Abort()
189+
return
190+
} else {
191+
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
192+
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
193+
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
194+
}
195+
}
196+
}
197+
198+
// SearchRateLimit returns a per-user rate limiter for search endpoints.
199+
// 10 requests per 60 seconds per user (by user ID, not IP).
200+
func SearchRateLimit() func(c *gin.Context) {
201+
return userRateLimitFactory(common.SearchRateLimitNum, common.SearchRateLimitDuration, "SR")
202+
}

model/token.go

Lines changed: 95 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"strings"
77

88
"github.com/QuantumNous/new-api/common"
9+
"github.com/QuantumNous/new-api/setting/operation_setting"
910
"github.com/bytedance/gopkg/util/gopool"
1011
"gorm.io/gorm"
1112
)
@@ -63,12 +64,103 @@ func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
6364
return tokens, err
6465
}
6566

66-
func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token, err error) {
67+
// sanitizeLikePattern 校验并清洗用户输入的 LIKE 搜索模式。
68+
// 规则:
69+
// 1. 转义 _ 和 \(不允许 _ 作通配符)
70+
// 2. 连续的 % 合并为单个 %
71+
// 3. 最多允许 2 个 %
72+
// 4. 含 % 时(模糊搜索),去掉 % 后关键词长度必须 >= 2
73+
// 5. 不含 % 时按精确匹配
74+
func sanitizeLikePattern(input string) (string, error) {
75+
// 1. 转义 \ 和 _
76+
input = strings.ReplaceAll(input, `\`, `\\`)
77+
input = strings.ReplaceAll(input, `_`, `\_`)
78+
79+
// 2. 连续的 % 直接拒绝
80+
if strings.Contains(input, "%%") {
81+
return "", errors.New("搜索模式中不允许包含连续的 % 通配符")
82+
}
83+
84+
// 3. 统计 % 数量,不得超过 2
85+
count := strings.Count(input, "%")
86+
if count > 2 {
87+
return "", errors.New("搜索模式中最多允许包含 2 个 % 通配符")
88+
}
89+
90+
// 4. 含 % 时,去掉 % 后关键词长度必须 >= 2
91+
if count > 0 {
92+
stripped := strings.ReplaceAll(input, "%", "")
93+
if len(stripped) < 2 {
94+
return "", errors.New("使用模糊搜索时,关键词长度至少为 2 个字符")
95+
}
96+
return input, nil
97+
}
98+
99+
// 5. 无 % 时,精确全匹配
100+
return input, nil
101+
}
102+
103+
const searchHardLimit = 100
104+
105+
func SearchUserTokens(userId int, keyword string, token string, offset int, limit int) (tokens []*Token, total int64, err error) {
106+
// model 层强制截断
107+
if limit <= 0 || limit > searchHardLimit {
108+
limit = searchHardLimit
109+
}
110+
if offset < 0 {
111+
offset = 0
112+
}
113+
67114
if token != "" {
68115
token = strings.Trim(token, "sk-")
69116
}
70-
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
71-
return tokens, err
117+
118+
// 超量用户(令牌数超过上限)只允许精确搜索,禁止模糊搜索
119+
maxTokens := operation_setting.GetMaxUserTokens()
120+
hasFuzzy := strings.Contains(keyword, "%") || strings.Contains(token, "%")
121+
if hasFuzzy {
122+
count, err := CountUserTokens(userId)
123+
if err != nil {
124+
common.SysLog("failed to count user tokens: " + err.Error())
125+
return nil, 0, errors.New("获取令牌数量失败")
126+
}
127+
if int(count) > maxTokens {
128+
return nil, 0, errors.New("令牌数量超过上限,仅允许精确搜索,请勿使用 % 通配符")
129+
}
130+
}
131+
132+
baseQuery := DB.Model(&Token{}).Where("user_id = ?", userId)
133+
134+
// 非空才加 LIKE 条件,空则跳过(不过滤该字段)
135+
if keyword != "" {
136+
keywordPattern, err := sanitizeLikePattern(keyword)
137+
if err != nil {
138+
return nil, 0, err
139+
}
140+
baseQuery = baseQuery.Where("name LIKE ? ESCAPE '\\'", keywordPattern)
141+
}
142+
if token != "" {
143+
tokenPattern, err := sanitizeLikePattern(token)
144+
if err != nil {
145+
return nil, 0, err
146+
}
147+
baseQuery = baseQuery.Where(commonKeyCol+" LIKE ? ESCAPE '\\'", tokenPattern)
148+
}
149+
150+
// 先查匹配总数(用于分页,受 maxTokens 上限保护,避免全表 COUNT)
151+
err = baseQuery.Limit(maxTokens).Count(&total).Error
152+
if err != nil {
153+
common.SysError("failed to count search tokens: " + err.Error())
154+
return nil, 0, errors.New("搜索令牌失败")
155+
}
156+
157+
// 再分页查数据
158+
err = baseQuery.Order("id desc").Offset(offset).Limit(limit).Find(&tokens).Error
159+
if err != nil {
160+
common.SysError("failed to search tokens: " + err.Error())
161+
return nil, 0, errors.New("搜索令牌失败")
162+
}
163+
return tokens, total, nil
72164
}
73165

74166
func ValidateUserToken(key string) (token *Token, err error) {

router/api-router.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ func SetApiRouter(router *gin.Engine) {
238238
tokenRoute.Use(middleware.UserAuth())
239239
{
240240
tokenRoute.GET("/", controller.GetAllTokens)
241-
tokenRoute.GET("/search", controller.SearchTokens)
241+
tokenRoute.GET("/search", middleware.SearchRateLimit(), controller.SearchTokens)
242242
tokenRoute.GET("/:id", controller.GetToken)
243243
tokenRoute.POST("/", controller.AddToken)
244244
tokenRoute.PUT("/", controller.UpdateToken)

setting/config/config.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,13 +212,23 @@ func updateConfigFromMap(config interface{}, configMap map[string]string) error
212212
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
213213
intValue, err := strconv.ParseInt(strValue, 10, 64)
214214
if err != nil {
215-
continue
215+
// 兼容 float 格式的字符串(如 "2.000000")
216+
floatValue, fErr := strconv.ParseFloat(strValue, 64)
217+
if fErr != nil {
218+
continue
219+
}
220+
intValue = int64(floatValue)
216221
}
217222
field.SetInt(intValue)
218223
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
219224
uintValue, err := strconv.ParseUint(strValue, 10, 64)
220225
if err != nil {
221-
continue
226+
// 兼容 float 格式的字符串
227+
floatValue, fErr := strconv.ParseFloat(strValue, 64)
228+
if fErr != nil || floatValue < 0 {
229+
continue
230+
}
231+
uintValue = uint64(floatValue)
222232
}
223233
field.SetUint(uintValue)
224234
case reflect.Float32, reflect.Float64:
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package operation_setting
2+
3+
import "github.com/QuantumNous/new-api/setting/config"
4+
5+
// TokenSetting 令牌相关配置
6+
type TokenSetting struct {
7+
MaxUserTokens int `json:"max_user_tokens"` // 每用户最大令牌数量
8+
}
9+
10+
// 默认配置
11+
var tokenSetting = TokenSetting{
12+
MaxUserTokens: 1000, // 默认每用户最多 1000 个令牌
13+
}
14+
15+
func init() {
16+
// 注册到全局配置管理器
17+
config.GlobalConfig.Register("token_setting", &tokenSetting)
18+
}
19+
20+
// GetTokenSetting 获取令牌配置
21+
func GetTokenSetting() *TokenSetting {
22+
return &tokenSetting
23+
}
24+
25+
// GetMaxUserTokens 获取每用户最大令牌数量
26+
func GetMaxUserTokens() int {
27+
return GetTokenSetting().MaxUserTokens
28+
}

web/src/components/settings/OperationSetting.jsx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ const OperationSetting = () => {
7878
'checkin_setting.enabled': false,
7979
'checkin_setting.min_quota': 1000,
8080
'checkin_setting.max_quota': 10000,
81+
82+
/* 令牌设置 */
83+
'token_setting.max_user_tokens': 1000,
8184
});
8285

8386
let [loading, setLoading] = useState(false);

0 commit comments

Comments
 (0)