@@ -29,10 +29,15 @@ import com.tchat.data.database.entity.AssistantEntity
2929import com.tchat.data.model.Assistant
3030import com.tchat.data.model.LocalToolOption
3131import com.tchat.data.repository.impl.ChatRepositoryImpl
32+ import com.tchat.data.repository.impl.KnowledgeRepositoryImpl
33+ import com.tchat.data.service.KnowledgeService
34+ import com.tchat.data.tool.KnowledgeSearchTool
3235import com.tchat.data.tool.LocalTools
36+ import com.tchat.data.tool.Tool
3337import com.tchat.feature.chat.ChatScreen
3438import com.tchat.feature.chat.ChatViewModel
3539import com.tchat.network.provider.AIProviderFactory
40+ import com.tchat.network.provider.EmbeddingProviderFactory
3641import com.tchat.wanxiaot.settings.AIProviderType
3742import com.tchat.wanxiaot.settings.SettingsManager
3843import com.tchat.wanxiaot.ui.DrawerContent
@@ -133,6 +138,18 @@ fun MainScreen(
133138 val messageDao = database.messageDao()
134139 val assistantDao = database.assistantDao()
135140
141+ // 知识库相关
142+ val knowledgeRepository = remember(database) {
143+ KnowledgeRepositoryImpl (
144+ database.knowledgeBaseDao(),
145+ database.knowledgeItemDao(),
146+ database.knowledgeChunkDao()
147+ )
148+ }
149+ val knowledgeService = remember(knowledgeRepository) {
150+ KnowledgeService (knowledgeRepository)
151+ }
152+
136153 // null表示新对话(懒创建模式)
137154 var currentChatId by remember { mutableStateOf<String ?>(null ) }
138155 var chatList by remember { mutableStateOf(emptyList< com.tchat.data.model.Chat > ()) }
@@ -349,6 +366,33 @@ fun MainScreen(
349366 LocalTools (context)
350367 }
351368
369+ // 计算知识库搜索工具(当 currentAssistant 变化时重新计算)
370+ val knowledgeTools = remember(currentAssistant?.knowledgeBaseId, knowledgeService, knowledgeRepository) {
371+ currentAssistant?.knowledgeBaseId?.let { kbId ->
372+ println (" === 知识库工具已启用 ===" )
373+ println (" 知识库ID: $kbId " )
374+ println (" 助手: ${currentAssistant?.name} " )
375+ listOf (
376+ KnowledgeSearchTool .create(
377+ knowledgeService = knowledgeService,
378+ repository = knowledgeRepository,
379+ getEmbeddingProvider = { knowledgeBaseId ->
380+ // 从知识库配置获取 Embedding Provider
381+ getEmbeddingProviderForKnowledgeBase(
382+ knowledgeBaseId = knowledgeBaseId,
383+ knowledgeRepository = knowledgeRepository,
384+ settingsManager = settingsManager
385+ )
386+ },
387+ knowledgeBaseId = kbId
388+ )
389+ )
390+ } ? : run {
391+ println (" === 未绑定知识库 ===" )
392+ emptyList()
393+ }
394+ }
395+
352396 ChatScreen (
353397 viewModel = viewModel,
354398 chatId = currentChatId,
@@ -372,6 +416,8 @@ fun MainScreen(
372416 getToolsForOptions = { options ->
373417 localTools.getToolsForOptions(options)
374418 },
419+ // 知识库搜索工具作为额外工具传递
420+ extraTools = knowledgeTools,
375421 systemPrompt = currentAssistant?.systemPrompt
376422 )
377423 }
@@ -435,7 +481,46 @@ private fun entityToAssistant(entity: AssistantEntity): Assistant {
435481 contextMessageSize = entity.contextMessageSize,
436482 streamOutput = entity.streamOutput,
437483 localTools = toolOptions,
484+ knowledgeBaseId = entity.knowledgeBaseId,
438485 createdAt = entity.createdAt,
439486 updatedAt = entity.updatedAt
440487 )
441488}
489+
490+ /* *
491+ * 根据知识库配置获取对应的 Embedding Provider
492+ * 知识库使用自己配置的 Embedding 服务商,与对话模型提供商独立
493+ */
494+ private fun getEmbeddingProviderForKnowledgeBase (
495+ knowledgeBaseId : String ,
496+ knowledgeRepository : KnowledgeRepositoryImpl ,
497+ settingsManager : SettingsManager
498+ ): com.tchat.network.provider.EmbeddingProvider ? {
499+ return try {
500+ // 获取知识库配置(同步方式,因为我们在工具执行时需要)
501+ val base = kotlinx.coroutines.runBlocking {
502+ knowledgeRepository.getBaseById(knowledgeBaseId)
503+ } ? : return null
504+
505+ // 获取设置中的服务商配置
506+ val settings = settingsManager.settings.value
507+ val providerConfig = settings.providers.find { it.id == base.embeddingProviderId }
508+ ? : return null
509+
510+ // 根据服务商类型创建 Embedding Provider
511+ val providerType = when (providerConfig.providerType) {
512+ AIProviderType .OPENAI -> EmbeddingProviderFactory .EmbeddingProviderType .OPENAI
513+ AIProviderType .GEMINI -> EmbeddingProviderFactory .EmbeddingProviderType .GEMINI
514+ else -> return null
515+ }
516+
517+ EmbeddingProviderFactory .create(
518+ type = providerType,
519+ apiKey = providerConfig.apiKey,
520+ baseUrl = providerConfig.endpoint.ifBlank { null }
521+ )
522+ } catch (e: Exception ) {
523+ e.printStackTrace()
524+ null
525+ }
526+ }
0 commit comments