Skip to content

Commit a840c83

Browse files
authored
Merge pull request #55 from chaitin/feat-stream-check
feat: 支持流式check
2 parents 5cda2d9 + 60f1b07 commit a840c83

File tree

1 file changed

+43
-6
lines changed

1 file changed

+43
-6
lines changed

usecase/modelkit.go

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"context"
66
"encoding/json"
77
"fmt"
8+
"io"
89
"log"
910
"net/http"
1011
"net/url"
@@ -211,11 +212,11 @@ func CheckModel(ctx context.Context, req *domain.CheckModelReq) (*domain.CheckMo
211212
return checkResp, nil
212213
}
213214

214-
if resp.Content == "" {
215+
if resp == "" {
215216
checkResp.Error = "生成内容失败"
216217
return checkResp, nil
217218
}
218-
checkResp.Content = resp.Content
219+
checkResp.Content = resp
219220
return checkResp, nil
220221
}
221222

@@ -286,7 +287,7 @@ func GetChatModel(ctx context.Context, model *domain.ModelMetadata) (model.BaseC
286287
}
287288

288289
// 以下是辅助函数,用于处理模型列表和检查相关的功能
289-
func getChatModelGenerateChat(ctx context.Context, provider consts.ModelProvider, modelType consts.ModelType, baseURL string, req *domain.CheckModelReq) (*schema.Message, error) {
290+
func getChatModelGenerateChat(ctx context.Context, provider consts.ModelProvider, modelType consts.ModelType, baseURL string, req *domain.CheckModelReq) (string, error) {
290291
chatModel, err := GetChatModel(ctx, &domain.ModelMetadata{
291292
Provider: provider,
292293
ModelName: req.Model,
@@ -297,13 +298,49 @@ func getChatModelGenerateChat(ctx context.Context, provider consts.ModelProvider
297298
ModelType: modelType,
298299
})
299300
if err != nil {
300-
return nil, err
301+
return "", err
301302
}
302303

303-
return chatModel.Generate(ctx, []*schema.Message{
304+
genResp, err := chatModel.Generate(ctx, []*schema.Message{
305+
schema.SystemMessage("You are a helpful assistant."),
306+
schema.UserMessage("hi"),
307+
})
308+
// 非流式生成失败,尝试流式生成
309+
if err != nil || genResp.Content == "" {
310+
log.Printf("Generate chat failed, err: %v", err)
311+
streamRes, streamErr := streamCheck(ctx, &chatModel)
312+
if streamErr != nil {
313+
log.Printf("Stream chat failed, err: %v", streamErr)
314+
return "", err
315+
}
316+
return streamRes, nil
317+
}
318+
319+
return genResp.Content, nil
320+
}
321+
322+
func streamCheck(ctx context.Context, chatModel *model.BaseChatModel) (string, error) {
323+
var res string
324+
streamResult, err := (*chatModel).Stream(ctx, []*schema.Message{
304325
schema.SystemMessage("You are a helpful assistant."),
305326
schema.UserMessage("hi"),
306327
})
328+
if err != nil {
329+
return "", err
330+
}
331+
332+
for {
333+
chunk, err := streamResult.Recv()
334+
if err == io.EOF {
335+
break
336+
}
337+
if err != nil {
338+
// 错误处理
339+
}
340+
// 响应片段处理
341+
res += chunk.Content
342+
}
343+
return res, nil
307344
}
308345

309346
// baseURL添加/v1
@@ -374,7 +411,7 @@ func reqModelListApi[T domain.ModelResponseParser](req *domain.ModelListReq, htt
374411
}
375412

376413
func generateBaseURLFixSuggestion(errContent string, baseURL string, provider consts.ModelProvider) string {
377-
var is404, isLocal, hasPath , isOther bool
414+
var is404, isLocal, hasPath, isOther bool
378415
if strings.Contains(errContent, "404") || strings.Contains(errContent, "connection refused") {
379416
is404 = true
380417
}

0 commit comments

Comments
 (0)