|
6 | 6 | "strings" |
7 | 7 |
|
8 | 8 | "github.com/QuantumNous/new-api/common" |
| 9 | + "github.com/QuantumNous/new-api/setting/operation_setting" |
9 | 10 | "github.com/bytedance/gopkg/util/gopool" |
10 | 11 | "gorm.io/gorm" |
11 | 12 | ) |
@@ -63,12 +64,103 @@ func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { |
63 | 64 | return tokens, err |
64 | 65 | } |
65 | 66 |
|
66 | | -func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token, err error) { |
| 67 | +// sanitizeLikePattern 校验并清洗用户输入的 LIKE 搜索模式。 |
| 68 | +// 规则: |
| 69 | +// 1. 转义 _ 和 \(不允许 _ 作通配符) |
| 70 | +// 2. 连续的 % 合并为单个 % |
| 71 | +// 3. 最多允许 2 个 % |
| 72 | +// 4. 含 % 时(模糊搜索),去掉 % 后关键词长度必须 >= 2 |
| 73 | +// 5. 不含 % 时按精确匹配 |
| 74 | +func sanitizeLikePattern(input string) (string, error) { |
| 75 | + // 1. 转义 \ 和 _ |
| 76 | + input = strings.ReplaceAll(input, `\`, `\\`) |
| 77 | + input = strings.ReplaceAll(input, `_`, `\_`) |
| 78 | + |
| 79 | + // 2. 连续的 % 直接拒绝 |
| 80 | + if strings.Contains(input, "%%") { |
| 81 | + return "", errors.New("搜索模式中不允许包含连续的 % 通配符") |
| 82 | + } |
| 83 | + |
| 84 | + // 3. 统计 % 数量,不得超过 2 |
| 85 | + count := strings.Count(input, "%") |
| 86 | + if count > 2 { |
| 87 | + return "", errors.New("搜索模式中最多允许包含 2 个 % 通配符") |
| 88 | + } |
| 89 | + |
| 90 | + // 4. 含 % 时,去掉 % 后关键词长度必须 >= 2 |
| 91 | + if count > 0 { |
| 92 | + stripped := strings.ReplaceAll(input, "%", "") |
| 93 | + if len(stripped) < 2 { |
| 94 | + return "", errors.New("使用模糊搜索时,关键词长度至少为 2 个字符") |
| 95 | + } |
| 96 | + return input, nil |
| 97 | + } |
| 98 | + |
| 99 | + // 5. 无 % 时,精确全匹配 |
| 100 | + return input, nil |
| 101 | +} |
| 102 | + |
| 103 | +const searchHardLimit = 100 |
| 104 | + |
| 105 | +func SearchUserTokens(userId int, keyword string, token string, offset int, limit int) (tokens []*Token, total int64, err error) { |
| 106 | + // model 层强制截断 |
| 107 | + if limit <= 0 || limit > searchHardLimit { |
| 108 | + limit = searchHardLimit |
| 109 | + } |
| 110 | + if offset < 0 { |
| 111 | + offset = 0 |
| 112 | + } |
| 113 | + |
67 | 114 | if token != "" { |
68 | 115 | token = strings.Trim(token, "sk-") |
69 | 116 | } |
70 | | - err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error |
71 | | - return tokens, err |
| 117 | + |
| 118 | + // 超量用户(令牌数超过上限)只允许精确搜索,禁止模糊搜索 |
| 119 | + maxTokens := operation_setting.GetMaxUserTokens() |
| 120 | + hasFuzzy := strings.Contains(keyword, "%") || strings.Contains(token, "%") |
| 121 | + if hasFuzzy { |
| 122 | + count, err := CountUserTokens(userId) |
| 123 | + if err != nil { |
| 124 | + common.SysLog("failed to count user tokens: " + err.Error()) |
| 125 | + return nil, 0, errors.New("获取令牌数量失败") |
| 126 | + } |
| 127 | + if int(count) > maxTokens { |
| 128 | + return nil, 0, errors.New("令牌数量超过上限,仅允许精确搜索,请勿使用 % 通配符") |
| 129 | + } |
| 130 | + } |
| 131 | + |
| 132 | + baseQuery := DB.Model(&Token{}).Where("user_id = ?", userId) |
| 133 | + |
| 134 | + // 非空才加 LIKE 条件,空则跳过(不过滤该字段) |
| 135 | + if keyword != "" { |
| 136 | + keywordPattern, err := sanitizeLikePattern(keyword) |
| 137 | + if err != nil { |
| 138 | + return nil, 0, err |
| 139 | + } |
| 140 | + baseQuery = baseQuery.Where("name LIKE ? ESCAPE '\\'", keywordPattern) |
| 141 | + } |
| 142 | + if token != "" { |
| 143 | + tokenPattern, err := sanitizeLikePattern(token) |
| 144 | + if err != nil { |
| 145 | + return nil, 0, err |
| 146 | + } |
| 147 | + baseQuery = baseQuery.Where(commonKeyCol+" LIKE ? ESCAPE '\\'", tokenPattern) |
| 148 | + } |
| 149 | + |
| 150 | + // 先查匹配总数(用于分页,受 maxTokens 上限保护,避免全表 COUNT) |
| 151 | + err = baseQuery.Limit(maxTokens).Count(&total).Error |
| 152 | + if err != nil { |
| 153 | + common.SysError("failed to count search tokens: " + err.Error()) |
| 154 | + return nil, 0, errors.New("搜索令牌失败") |
| 155 | + } |
| 156 | + |
| 157 | + // 再分页查数据 |
| 158 | + err = baseQuery.Order("id desc").Offset(offset).Limit(limit).Find(&tokens).Error |
| 159 | + if err != nil { |
| 160 | + common.SysError("failed to search tokens: " + err.Error()) |
| 161 | + return nil, 0, errors.New("搜索令牌失败") |
| 162 | + } |
| 163 | + return tokens, total, nil |
72 | 164 | } |
73 | 165 |
|
74 | 166 | func ValidateUserToken(key string) (token *Token, err error) { |
|
0 commit comments