diff --git a/backend/consts/model.py b/backend/consts/model.py index 986a5ce5..2c4c2a2c 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -211,6 +211,8 @@ class AgentInfoRequest(BaseModel): constraint_prompt: Optional[str] = None few_shots_prompt: Optional[str] = None enabled: Optional[bool] = None + business_logic_model_name: Optional[str] = None + business_logic_model_id: Optional[int] = None class AgentIDRequest(BaseModel): diff --git a/backend/data_process/tasks.py b/backend/data_process/tasks.py index 0cf85452..0cd43108 100644 --- a/backend/data_process/tasks.py +++ b/backend/data_process/tasks.py @@ -201,8 +201,8 @@ def process( f"[{self.request.id}] PROCESS TASK: File size: {file_size_mb:.2f}MB") # The unified actor call, mapping 'file' source_type to 'local' destination - # Submit Ray work and do not block here - logger.debug( + # Submit Ray work and WAIT for processing to complete + logger.info( f"[{self.request.id}] PROCESS TASK: Submitting Ray processing for source='{source}', strategy='{chunking_strategy}', destination='{source_type}'") chunks_ref = actor.process_file.remote( source, @@ -211,10 +211,17 @@ def process( task_id=task_id, **params ) - # Persist chunks into Redis via Ray to decouple Celery + # Wait for Ray processing to complete (this keeps task in STARTED/"PROCESSING" state) + logger.info( + f"[{self.request.id}] PROCESS TASK: Waiting for Ray processing to complete...") + chunks = ray.get(chunks_ref) + logger.info( + f"[{self.request.id}] PROCESS TASK: Ray processing completed, got {len(chunks) if chunks else 0} chunks") + + # Persist chunks into Redis via Ray (fire-and-forget, don't block) redis_key = f"dp:{task_id}:chunks" - actor.store_chunks_in_redis.remote(redis_key, chunks_ref) - logger.debug( + actor.store_chunks_in_redis.remote(redis_key, chunks) + logger.info( f"[{self.request.id}] PROCESS TASK: Scheduled store_chunks_in_redis for key '{redis_key}'") end_time = time.time() @@ -229,7 +236,7 @@ def process( f"[{self.request.id}] PROCESS TASK: Processing from URL: {source}") # For URL source, core.py expects a non-local destination to trigger URL fetching - logger.debug( + logger.info( f"[{self.request.id}] PROCESS TASK: Submitting Ray processing for URL='{source}', strategy='{chunking_strategy}', destination='{source_type}'") chunks_ref = actor.process_file.remote( source, @@ -238,11 +245,19 @@ def process( task_id=task_id, **params ) - # Persist chunks into Redis via Ray to decouple Celery + # Wait for Ray processing to complete (this keeps task in STARTED/"PROCESSING" state) + logger.info( + f"[{self.request.id}] PROCESS TASK: Waiting for Ray processing to complete...") + chunks = ray.get(chunks_ref) + logger.info( + f"[{self.request.id}] PROCESS TASK: Ray processing completed, got {len(chunks) if chunks else 0} chunks") + + # Persist chunks into Redis via Ray (fire-and-forget, don't block) redis_key = f"dp:{task_id}:chunks" - actor.store_chunks_in_redis.remote(redis_key, chunks_ref) - logger.debug( + actor.store_chunks_in_redis.remote(redis_key, chunks) + logger.info( f"[{self.request.id}] PROCESS TASK: Scheduled store_chunks_in_redis for key '{redis_key}'") + end_time = time.time() elapsed_time = end_time - start_time logger.info( @@ -253,11 +268,12 @@ def process( raise NotImplementedError( f"Source type '{source_type}' not yet supported") - # Update task state to SUCCESS with metadata (without materializing chunks here) + # Update task state to SUCCESS after Ray processing completes + # This transitions from STARTED (PROCESSING) to 
SUCCESS (WAIT_FOR_FORWARDING) self.update_state( state=states.SUCCESS, meta={ - 'chunks_count': None, + 'chunks_count': len(chunks) if chunks else 0, 'processing_time': elapsed_time, 'source': source, 'index_name': index_name, @@ -265,12 +281,12 @@ def process( 'task_name': 'process', 'stage': 'text_extracted', 'file_size_mb': file_size_mb, - 'processing_speed_mb_s': file_size_mb / elapsed_time if elapsed_time > 0 else 0 + 'processing_speed_mb_s': file_size_mb / elapsed_time if file_size_mb > 0 and elapsed_time > 0 else 0 } ) logger.info( - f"[{self.request.id}] PROCESS TASK: Submitted for Ray processing; result will be fetched by forward") + f"[{self.request.id}] PROCESS TASK: Processing complete, waiting for forward task") # Prepare data for the next task in the chain; pass redis_key returned_data = { @@ -563,6 +579,9 @@ async def index_documents(): "source": original_source, "original_filename": original_filename }, ensure_ascii=False)) + + logger.info( + f"[{self.request.id}] FORWARD TASK: Starting ES indexing for {len(formatted_chunks)} chunks to index '{original_index_name}'...") es_result = run_async(index_documents()) logger.debug( f"[{self.request.id}] FORWARD TASK: API response from main_server for source '{original_source}': {es_result}") @@ -605,6 +624,8 @@ async def index_documents(): "original_filename": original_filename }, ensure_ascii=False)) end_time = time.time() + logger.info( + f"[{self.request.id}] FORWARD TASK: Updating task state to SUCCESS after ES indexing completion") self.update_state( state=states.SUCCESS, meta={ @@ -620,7 +641,7 @@ async def index_documents(): ) logger.info( - f"Stored {len(chunks)} chunks to index {original_index_name} in {end_time - start_time:.2f}s") + f"[{self.request.id}] FORWARD TASK: Successfully stored {len(chunks)} chunks to index {original_index_name} in {end_time - start_time:.2f}s") return { 'task_id': task_id, 'source': original_source, diff --git a/backend/database/db_models.py b/backend/database/db_models.py index 7087a2ae..7acb836e 100644 --- a/backend/database/db_models.py +++ b/backend/database/db_models.py @@ -206,6 +206,8 @@ class AgentInfo(TableBase): Boolean, doc="Whether to provide the running summary to the manager agent") business_description = Column( Text, doc="Manually entered by the user to describe the entire business process") + business_logic_model_name = Column(String(100), doc="Model name used for business logic prompt generation") + business_logic_model_id = Column(Integer, doc="Model ID used for business logic prompt generation, foreign key reference to model_record_t.model_id") class ToolInstance(TableBase): diff --git a/backend/prompts/cluster_summary_agent.yaml b/backend/prompts/cluster_summary_agent.yaml new file mode 100644 index 00000000..ed614ed0 --- /dev/null +++ b/backend/prompts/cluster_summary_agent.yaml @@ -0,0 +1,24 @@ +system_prompt: |- + You are a professional knowledge summarization assistant. Your task is to generate a concise summary of a document cluster based on multiple documents. + + **Summary Requirements:** + 1. The input contains multiple documents (each document has title and content snippets) + 2. You need to extract the common themes and key topics from these documents + 3. Generate a summary that represents the collective content of the cluster + 4. The summary should be accurate, coherent, and written in natural language + 5. 
Keep the summary within the specified word limit + + **Guidelines:** + - Focus on identifying shared themes and topics across documents + - Highlight key concepts, domains, or subject matter + - Use clear and concise language + - Avoid listing individual document titles unless necessary + - The summary should help users understand what this group of documents covers + +user_prompt: | + Please generate a concise summary of the following document cluster: + + {{ cluster_content }} + + Summary ({{ max_words }} words): + diff --git a/backend/prompts/cluster_summary_reduce.yaml b/backend/prompts/cluster_summary_reduce.yaml new file mode 100644 index 00000000..ece36081 --- /dev/null +++ b/backend/prompts/cluster_summary_reduce.yaml @@ -0,0 +1,31 @@ +system_prompt: |- + You are a professional cluster summarization assistant. Your task is to merge multiple document summaries into a cohesive cluster summary. + + **Summary Requirements:** + 1. The input contains summaries of multiple documents that belong to the same cluster + 2. These documents share similar themes or topics (grouped by clustering) + 3. You need to synthesize a unified summary that captures the collective content + 4. The summary should highlight common themes and key information across documents + 5. Keep the summary within the specified word limit + + **Guidelines:** + - Identify shared themes and topics across documents + - Highlight common concepts and subject matter + - Use clear and concise language + - Avoid listing individual document titles unless necessary + - Focus on what this group of documents collectively covers + - The summary should be coherent and represent the cluster's unified content + - **Important: Do not use any separators (like ---, ***, etc.), generate plain text summary only** + +user_prompt: | + Please generate a unified summary of the following document cluster based on individual document summaries: + + {{ document_summaries }} + + **Important Reminders:** + - Do not use any separators (like ---, ***, ===, etc.) + - Do not include document titles or filenames + - Generate plain text summary content only + + Cluster Summary ({{ max_words }} words): + diff --git a/backend/prompts/cluster_summary_reduce_zh.yaml b/backend/prompts/cluster_summary_reduce_zh.yaml new file mode 100644 index 00000000..f6ef4a64 --- /dev/null +++ b/backend/prompts/cluster_summary_reduce_zh.yaml @@ -0,0 +1,32 @@ +system_prompt: |- + 你是一个专业的簇总结助手。你的任务是将多个文档总结合并为一个连贯的簇总结。 + + **总结要求:** + 1. 输入包含属于同一簇的多个文档的总结 + 2. 这些文档共享相似的主题或话题(通过聚类分组) + 3. 你需要综合成一个统一的总结,捕捉集合内容 + 4. 总结应突出文档间的共同主题和关键信息 + 5. 保持在指定的字数限制内 + + **指导原则:** + - 识别文档间的共同主题和话题 + - 突出共同概念和主题内容 + - 使用清晰简洁的语言 + - 除非必要,避免列出单个文档标题 + - 专注于这组文档共同涵盖的内容 + - 总结应连贯且代表簇的统一内容 + - 确保准确、全面,明确关键实体,不要遗漏重要信息 + - **重要:不要使用任何分隔符(如---、***等),直接生成纯文本总结** + +user_prompt: | + 请根据以下文档总结生成统一的学生簇总结: + + {{ document_summaries }} + + **重要提醒:** + - 不要使用任何分隔符(如---、***、===等) + - 不要包含文档标题或文件名 + - 直接生成纯文本总结内容 + + 簇总结({{ max_words }}字): + diff --git a/backend/prompts/document_summary_agent.yaml b/backend/prompts/document_summary_agent.yaml new file mode 100644 index 00000000..88b4d9a9 --- /dev/null +++ b/backend/prompts/document_summary_agent.yaml @@ -0,0 +1,28 @@ +system_prompt: |- + You are a professional document summarization assistant. Your task is to generate a concise summary of a document based on its key content snippets. + + **Summary Requirements:** + 1. The input contains key snippets from a document (typically from beginning, middle, and end sections) + 2. 
You need to extract the main themes, topics, and key information + 3. Generate a summary that represents the document's core content + 4. The summary should be accurate, coherent, and concise + 5. Keep the summary within the specified word limit + + **Guidelines:** + - Focus on identifying main themes and key topics + - Highlight important concepts and information + - Use clear and concise language + - Avoid redundancy and unnecessary details + - The summary should help users understand what the document covers + - **Important: Do not use any separators (like ---, ***, etc.), generate plain text summary only** + +user_prompt: | + Please generate a concise summary of the following document: + + Document name: {{ filename }} + + Content snippets: + {{ content }} + + Summary ({{ max_words }} words): + diff --git a/backend/prompts/document_summary_agent_zh.yaml b/backend/prompts/document_summary_agent_zh.yaml new file mode 100644 index 00000000..4f443ca3 --- /dev/null +++ b/backend/prompts/document_summary_agent_zh.yaml @@ -0,0 +1,29 @@ +system_prompt: |- + 你是一个专业的文档总结助手。你的任务是根据文档的关键内容片段生成简洁的总结。 + + **总结要求:** + 1. 输入包含文档的关键片段(通常来自开头、中间和结尾部分) + 2. 你需要提取主要主题、话题和关键信息 + 3. 生成能代表文档核心内容的总结 + 4. 总结应准确、连贯且简洁 + 5. 保持在指定的字数限制内 + + **指导原则:** + - 专注于识别主要主题和关键话题 + - 突出重要概念和信息 + - 使用清晰简洁的语言 + - 避免冗余和不必要的细节 + - 总结应帮助用户理解文档涵盖的内容 + - 确保总结准确、全面,不要遗漏关键实体和信息 + - **重要:不要使用任何分隔符(如---、***等),直接生成纯文本总结** + +user_prompt: | + 请为以下文档生成简洁的总结: + + 文档名称:{{ filename }} + + 内容片段: + {{ content }} + + 总结({{ max_words }}字): + diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 777ca3cd..bc4187be 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -14,7 +14,9 @@ dependencies = [ "pyyaml>=6.0.2", "redis>=5.0.0", "fastmcp==2.12.0", - "langchain>=0.3.26" + "langchain>=0.3.26", + "scikit-learn>=1.0.0", + "numpy>=1.24.0" ] [project.optional-dependencies] diff --git a/backend/services/agent_service.py b/backend/services/agent_service.py index 6da1aeaf..a8ae173a 100644 --- a/backend/services/agent_service.py +++ b/backend/services/agent_service.py @@ -236,6 +236,13 @@ async def get_agent_info_impl(agent_id: int, tenant_id: str): else: agent_info["model_name"] = None + # Get business logic model display name from model_id + if agent_info.get("business_logic_model_id") is not None: + business_logic_model_info = get_model_by_model_id(agent_info["business_logic_model_id"]) + agent_info["business_logic_model_name"] = business_logic_model_info.get("display_name", None) if business_logic_model_info is not None else None + elif "business_logic_model_name" not in agent_info: + agent_info["business_logic_model_name"] = None + return agent_info diff --git a/backend/services/elasticsearch_service.py b/backend/services/elasticsearch_service.py index aa386d02..5193c2e4 100644 --- a/backend/services/elasticsearch_service.py +++ b/backend/services/elasticsearch_service.py @@ -18,14 +18,11 @@ from fastapi import Body, Depends, Path, Query from fastapi.responses import StreamingResponse -from jinja2 import Template, StrictUndefined from nexent.core.models.embedding_model import OpenAICompatibleEmbedding, JinaEmbedding, BaseEmbedding from nexent.core.nlp.tokenizer import calculate_term_weights from nexent.vector_database.elasticsearch_core import ElasticSearchCore -from openai import OpenAI -from openai.types.chat import ChatCompletionMessageParam -from consts.const import ES_API_KEY, ES_HOST, LANGUAGE, MODEL_CONFIG_MAPPING, MESSAGE_ROLE, KNOWLEDGE_SUMMARY_MAX_TOKENS_ZH, KNOWLEDGE_SUMMARY_MAX_TOKENS_EN +from consts.const 
import ES_API_KEY, ES_HOST, LANGUAGE from database.attachment_db import delete_file from database.knowledge_db import ( create_knowledge_record, @@ -36,7 +33,6 @@ from services.redis_service import get_redis_service from utils.config_utils import tenant_config_manager, get_model_name_from_config from utils.file_management_utils import get_all_files_status, get_file_size -from utils.prompt_template_utils import get_knowledge_summary_prompt_template # Configure logging logger = logging.getLogger("elasticsearch_service") @@ -44,89 +40,8 @@ -def generate_knowledge_summary_stream(keywords: str, language: str, tenant_id: str, model_id: Optional[int] = None) -> Generator: - """ - Generate a knowledge base summary based on keywords - - Args: - keywords: Keywords that frequently appear in the knowledge base content - language: Language of the knowledge base content - tenant_id: The tenant ID for configuration - - Returns: - str: Generate a knowledge base summary - """ - # Load prompt words based on language - prompts = get_knowledge_summary_prompt_template(language) - - # Render templates using Jinja2 - system_prompt = Template( - prompts['system_prompt'], undefined=StrictUndefined).render({}) - user_prompt = Template(prompts['user_prompt'], undefined=StrictUndefined).render( - {'content': keywords}) - - # Build messages - messages: List[ChatCompletionMessageParam] = [ - {"role": MESSAGE_ROLE["SYSTEM"], "content": system_prompt}, - {"role": MESSAGE_ROLE["USER"], "content": user_prompt} - ] - - # Get model configuration - if model_id: - try: - from database.model_management_db import get_model_by_model_id - model_info = get_model_by_model_id(model_id, tenant_id) - if model_info: - model_config = { - 'api_key': model_info.get('api_key', ''), - 'base_url': model_info.get('base_url', ''), - 'model_name': model_info.get('model_name', ''), - 'model_repo': model_info.get('model_repo', '') - } - else: - # Fallback to default model if specified model not found - logger.warning(f"Specified model {model_id} not found, falling back to default LLM.") - model_config = tenant_config_manager.get_model_config( - key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id) - except Exception as e: - logger.warning(f"Failed to get model {model_id}, using default model: {e}") - model_config = tenant_config_manager.get_model_config( - key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id) - else: - # Use default model configuration - model_config = tenant_config_manager.get_model_config( - key=MODEL_CONFIG_MAPPING["llm"], tenant_id=tenant_id) - - # initialize OpenAI client - client = OpenAI(api_key=model_config.get('api_key', ""), - base_url=model_config.get('base_url', "")) - - try: - # Create stream chat completion request - max_tokens = KNOWLEDGE_SUMMARY_MAX_TOKENS_ZH if language == LANGUAGE[ - "ZH"] else KNOWLEDGE_SUMMARY_MAX_TOKENS_EN - # Get model name for the request - model_name_for_request = model_config.get("model_name", "") - if model_config.get("model_repo"): - model_name_for_request = f"{model_config['model_repo']}/{model_name_for_request}" - - stream = client.chat.completions.create( - model=model_name_for_request, - messages=messages, - max_tokens=max_tokens, # add max_tokens limit - stream=True # enable stream output - ) - - # Iterate through stream response - for chunk in stream: - new_token = chunk.choices[0].delta.content - if new_token is not None: - yield new_token - yield "END" - - except Exception as e: - logger.error(f"Error occurred: {str(e)}") - yield f"Error: {str(e)}" +# Old keyword-based summary 
method removed - replaced with Map-Reduce approach +# See utils/document_vector_utils.py for new implementation # Initialize ElasticSearchCore instance with HTTPS support @@ -871,62 +786,85 @@ async def summary_index_name(self, model_id: Optional[int] = None ): """ - Generate a summary for the specified index based on its content + Generate a summary for the specified index using advanced Map-Reduce approach + + New implementation: + 1. Get documents and cluster them by semantic similarity + 2. Map: Summarize each document individually + 3. Reduce: Merge document summaries into cluster summaries + 4. Return: Combined knowledge base summary Args: index_name: Name of the index to summarize - batch_size: Number of documents to process per batch + batch_size: Number of documents to sample (default: 1000) es_core: ElasticSearchCore instance tenant_id: ID of the tenant language: Language of the summary (default: 'zh') + model_id: Model ID for LLM summarization Returns: StreamingResponse containing the generated summary """ try: - # Get all documents + from utils.document_vector_utils import ( + process_documents_for_clustering, + kmeans_cluster_documents, + summarize_clusters_map_reduce, + merge_cluster_summaries + ) + if not tenant_id: - raise Exception( - "Tenant ID is required for summary generation.") - all_documents = ElasticSearchService.get_random_documents( - index_name, batch_size, es_core) - all_chunks = self._clean_chunks_for_summary(all_documents) - keywords_dict = calculate_term_weights(all_chunks) - keywords_for_summary = "" - for _, key in enumerate(keywords_dict): - keywords_for_summary = keywords_for_summary + ", " + key - + raise Exception("Tenant ID is required for summary generation.") + + # Use new Map-Reduce approach + sample_count = min(batch_size // 5, 200) # Sample reasonable number of documents + + # Step 1: Get documents and calculate embeddings + document_samples, doc_embeddings = process_documents_for_clustering( + index_name=index_name, + es_core=es_core, + sample_doc_count=sample_count + ) + + if not document_samples: + raise Exception("No documents found in index.") + + # Step 2: Cluster documents + clusters = kmeans_cluster_documents(doc_embeddings, k=None) + + # Step 3: Map-Reduce summarization + cluster_summaries = summarize_clusters_map_reduce( + document_samples=document_samples, + clusters=clusters, + language=language, + doc_max_words=100, + cluster_max_words=150, + model_id=model_id, + tenant_id=tenant_id + ) + + # Step 4: Merge into final summary + final_summary = merge_cluster_summaries(cluster_summaries) + + # Stream the result async def generate_summary(): - token_join = [] try: - for new_token in generate_knowledge_summary_stream(keywords_for_summary, language, tenant_id, model_id): - if new_token == "END": - break - else: - token_join.append(new_token) - yield f"data: {{\"status\": \"success\", \"message\": \"{new_token}\"}}\n\n" - await asyncio.sleep(0.1) + # Stream the summary character by character + for char in final_summary: + yield f"data: {{\"status\": \"success\", \"message\": \"{char}\"}}\n\n" + await asyncio.sleep(0.01) + yield f"data: {{\"status\": \"completed\"}}\n\n" except Exception as e: yield f"data: {{\"status\": \"error\", \"message\": \"{e}\"}}\n\n" - - # Return the flow response + return StreamingResponse( generate_summary(), media_type="text/event-stream" ) - + except Exception as e: - raise Exception(f"{str(e)}") - - @staticmethod - def _clean_chunks_for_summary(all_documents): - # Only use these three fields for summarization - 
all_chunks = "" - for _, chunk in enumerate(all_documents['documents']): - all_chunks = all_chunks + "\n" + \ - chunk["title"] + "\n" + chunk["filename"] + \ - "\n" + chunk["content"] - return all_chunks + logger.error(f"Knowledge base summary generation failed: {str(e)}", exc_info=True) + raise Exception(f"Failed to generate summary: {str(e)}") @staticmethod def get_random_documents( diff --git a/backend/utils/document_vector_utils.py b/backend/utils/document_vector_utils.py new file mode 100644 index 00000000..5db8c215 --- /dev/null +++ b/backend/utils/document_vector_utils.py @@ -0,0 +1,786 @@ +""" +Document Vector Utilities Module + +This module provides utilities for document-level vector operations and clustering. +Main features: +1. Document-level vector calculation (weighted average of chunk vectors) +2. Automatic K-means clustering with optimal K determination +3. Document grouping and classification +4. Cluster summarization +""" +import logging +import random +from typing import Dict, List, Optional, Tuple + +import numpy as np +import yaml +from jinja2 import Template, StrictUndefined +from sklearn.cluster import KMeans +from sklearn.metrics import silhouette_score + +from consts.const import LANGUAGE + +logger = logging.getLogger("document_vector_utils") + + +def get_documents_from_es(index_name: str, es_core, sample_doc_count: int = 200) -> Dict[str, Dict]: + """ + Get document samples from Elasticsearch, aggregated by path_or_url + + Args: + index_name: Name of the index to query + es_core: ElasticSearchCore instance + sample_doc_count: Number of documents to sample + + Returns: + Dictionary mapping document IDs to document information with chunks + """ + try: + # Step 1: Aggregate unique documents by path_or_url + agg_query = { + "size": 0, + "aggs": { + "unique_documents": { + "terms": { + "field": "path_or_url", + "size": 10000 # Get all unique documents + } + } + } + } + + logger.info(f"Fetching unique documents from index {index_name}") + agg_response = es_core.client.search(index=index_name, body=agg_query) + all_documents = agg_response['aggregations']['unique_documents']['buckets'] + + if not all_documents: + logger.warning(f"No documents found in index {index_name}") + return {} + + # Step 2: Random sample documents + sample_count = min(sample_doc_count, len(all_documents)) + # Ensure all_documents is a list for random.sample + if not isinstance(all_documents, list): + all_documents = list(all_documents) + sampled_docs = random.sample(all_documents, sample_count) + + logger.info(f"Sampled {sample_count} documents from {len(all_documents)} total documents") + + # Step 3: Get all chunks for each sampled document + document_samples = {} + for doc_bucket in sampled_docs: + path_or_url = doc_bucket['key'] + chunk_count = doc_bucket['doc_count'] + + # Get all chunks for this document + chunks_query = { + "query": { + "term": {"path_or_url": path_or_url} + }, + "size": chunk_count # Get all chunks + } + + chunks_response = es_core.client.search(index=index_name, body=chunks_query) + chunks = [hit['_source'] for hit in chunks_response['hits']['hits']] + + # Build document object + if chunks: + doc_id = f"doc_{len(document_samples):04d}" + document_samples[doc_id] = { + "doc_id": doc_id, + "path_or_url": path_or_url, + "filename": chunks[0].get('filename', 'unknown'), + "chunk_count": chunk_count, + "chunks": chunks, + "file_size": chunks[0].get('file_size', 0) + } + + logger.info(f"Successfully retrieved {len(document_samples)} documents with chunks") + return 
document_samples + + except Exception as e: + logger.error(f"Error retrieving documents from ES: {str(e)}", exc_info=True) + raise Exception(f"Failed to retrieve documents from Elasticsearch: {str(e)}") + + +def calculate_document_embedding(doc_chunks: List[Dict], use_weighted: bool = True) -> Optional[np.ndarray]: + """ + Calculate document-level embedding from chunk embeddings + + Args: + doc_chunks: List of chunk dictionaries containing 'embedding' and 'content' fields + use_weighted: Whether to use weighted average based on content length + + Returns: + Document-level embedding vector or None if no valid embeddings found + """ + try: + embeddings = [] + weights = [] + + for chunk in doc_chunks: + chunk_embedding = chunk.get('embedding') + if chunk_embedding and isinstance(chunk_embedding, list): + embeddings.append(np.array(chunk_embedding)) + + if use_weighted: + # Weight by content length + content_length = len(chunk.get('content', '')) + position_weight = 1.5 if len(embeddings) == 1 else 1.0 # First chunk has higher weight + weight = position_weight * content_length + weights.append(weight) + + if not embeddings: + logger.warning("No valid embeddings found in chunks") + return None + + # Convert to numpy array + embeddings_array = np.array(embeddings) + + if use_weighted and weights: + # Weighted average + total_weight = sum(weights) + weights_normalized = np.array(weights) / total_weight + doc_embedding = np.average(embeddings_array, axis=0, weights=weights_normalized) + else: + # Simple average + doc_embedding = np.mean(embeddings_array, axis=0) + + return doc_embedding + + except Exception as e: + logger.error(f"Error calculating document embedding: {str(e)}", exc_info=True) + return None + + +def auto_determine_k(embeddings: np.ndarray, min_k: int = 3, max_k: int = 15) -> int: + """ + Automatically determine optimal K value for K-means clustering + + Args: + embeddings: Array of document embeddings + min_k: Minimum number of clusters + max_k: Maximum number of clusters + + Returns: + Optimal K value + """ + try: + n_samples = len(embeddings) + + # Handle edge cases + if n_samples < min_k: + return max(2, n_samples) + + if n_samples < 20: + # For small datasets, use simple heuristic + heuristic_k = max(min_k, min(int(np.sqrt(n_samples / 2)), max_k)) + return heuristic_k + + # Determine K range based on dataset size + actual_max_k = min(max_k, n_samples // 10, 15) # At least 10 samples per cluster + actual_min_k = min(min_k, actual_max_k) + + # Try different K values and calculate silhouette score + best_k = actual_min_k + best_score = -1 + + k_range = range(actual_min_k, actual_max_k + 1) + logger.info(f"Trying K values from {actual_min_k} to {actual_max_k}") + + for k in k_range: + try: + kmeans = KMeans(n_clusters=k, random_state=42, n_init=10, max_iter=300) + labels = kmeans.fit_predict(embeddings) + + # Calculate silhouette score + score = silhouette_score(embeddings, labels, sample_size=min(1000, n_samples)) + + logger.debug(f"K={k}, Silhouette Score={score:.4f}") + + if score > best_score: + best_score = score + best_k = k + + except Exception as e: + logger.warning(f"Error calculating K={k}: {str(e)}") + continue + + logger.info(f"Optimal K determined: {best_k} (Silhouette Score: {best_score:.4f})") + return best_k + + except Exception as e: + logger.error(f"Error in auto_determine_k: {str(e)}", exc_info=True) + # Fallback to heuristic + heuristic_k = max(min_k, min(int(np.sqrt(len(embeddings) / 2)), max_k)) + logger.warning(f"Using fallback K value: {heuristic_k}") + 
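# Worked example of the fallback heuristic (illustrative numbers, not from this change):
# with 200 sampled documents, int(np.sqrt(200 / 2)) = 10, so k falls back to
# max(min_k, min(10, max_k)) = 10 — the common sqrt(n/2) rule of thumb for a cluster count.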
return heuristic_k + + +def kmeans_cluster_documents(doc_embeddings: Dict[str, np.ndarray], k: Optional[int] = None) -> Dict[int, List[str]]: + """ + Cluster documents using K-means + + Args: + doc_embeddings: Dictionary mapping document IDs to their embeddings + k: Number of clusters (if None, auto-determined) + + Returns: + Dictionary mapping cluster IDs to lists of document IDs + """ + try: + if not doc_embeddings: + logger.warning("No document embeddings provided") + return {} + + # Prepare embeddings array + doc_ids = list(doc_embeddings.keys()) + embeddings_array = np.array([doc_embeddings[doc_id] for doc_id in doc_ids]) + + # Handle single document case + if len(doc_ids) == 1: + logger.info("Only one document found, skipping clustering") + return {0: doc_ids} + + # Determine K value + if k is None: + k = auto_determine_k(embeddings_array) + + # Ensure k is not greater than number of documents + k = min(k, len(doc_ids)) + + logger.info(f"Clustering {len(doc_ids)} documents into {k} clusters") + + # Perform K-means clustering + kmeans = KMeans(n_clusters=k, random_state=42, n_init=10, max_iter=300) + labels = kmeans.fit_predict(embeddings_array) + + # Group documents by cluster + clusters = {} + for i, label in enumerate(labels): + if label not in clusters: + clusters[label] = [] + clusters[label].append(doc_ids[i]) + + # Log cluster sizes + for cluster_id, docs in clusters.items(): + logger.info(f"Cluster {cluster_id}: {len(docs)} documents") + + return clusters + + except Exception as e: + logger.error(f"Error in K-means clustering: {str(e)}", exc_info=True) + raise Exception(f"Failed to cluster documents: {str(e)}") + + +def process_documents_for_clustering(index_name: str, es_core, sample_doc_count: int = 200) -> Tuple[Dict[str, Dict], Dict[str, np.ndarray]]: + """ + Complete workflow: Get documents from ES and calculate their embeddings + + Args: + index_name: Name of the index to query + es_core: ElasticSearchCore instance + sample_doc_count: Number of documents to sample + + Returns: + Tuple of (document_samples dict, doc_embeddings dict) + """ + try: + # Step 1: Get documents from ES + document_samples = get_documents_from_es(index_name, es_core, sample_doc_count) + + if not document_samples: + logger.warning("No documents retrieved from ES") + return {}, {} + + # Step 2: Calculate document-level embeddings + doc_embeddings = {} + for doc_id, doc_info in document_samples.items(): + chunks = doc_info['chunks'] + doc_embedding = calculate_document_embedding(chunks, use_weighted=True) + + if doc_embedding is not None: + doc_embeddings[doc_id] = doc_embedding + else: + logger.warning(f"Failed to calculate embedding for document {doc_id}") + + logger.info(f"Successfully calculated embeddings for {len(doc_embeddings)} documents") + return document_samples, doc_embeddings + + except Exception as e: + logger.error(f"Error processing documents for clustering: {str(e)}", exc_info=True) + raise Exception(f"Failed to process documents: {str(e)}") + + +def extract_cluster_content(document_samples: Dict[str, Dict], cluster_doc_ids: List[str], max_chunks_per_doc: int = 3) -> str: + """ + Extract representative content from a cluster for summarization + + Args: + document_samples: Dictionary mapping doc_id to document info + cluster_doc_ids: List of document IDs in the cluster + max_chunks_per_doc: Maximum number of chunks to include per document + + Returns: + Formatted string containing cluster content + """ + cluster_content_parts = [] + + for doc_id in cluster_doc_ids: + if doc_id not in 
document_samples: + continue + + doc_info = document_samples[doc_id] + chunks = doc_info.get('chunks', []) + filename = doc_info.get('filename', 'unknown') + + # Extract representative chunks + representative_chunks = [] + if len(chunks) <= max_chunks_per_doc: + representative_chunks = chunks + else: + # Take first, middle, and last chunks + representative_chunks = ( + chunks[:1] + + chunks[len(chunks)//2:len(chunks)//2+1] + + chunks[-1:] + ) + + # Format document content + doc_content = f"\n--- Document: {filename} ---\n" + for chunk in representative_chunks: + content = chunk.get('content', '') + # Limit chunk content length + if len(content) > 500: + content = content[:500] + "..." + doc_content += f"{content}\n" + + cluster_content_parts.append(doc_content) + + return "\n".join(cluster_content_parts) + + +def summarize_document(document_content: str, filename: str, language: str = LANGUAGE["ZH"], max_words: int = 100, model_id: Optional[int] = None, tenant_id: Optional[str] = None) -> str: + """ + Summarize a single document using LLM (Map stage) + + Args: + document_content: Formatted content from document chunks + filename: Document filename + language: Language code ('zh' or 'en') + max_words: Maximum words in the summary + model_id: Model ID for LLM call + tenant_id: Tenant ID for model configuration + + Returns: + Document summary text + """ + try: + # Select prompt file based on language + if language == LANGUAGE["ZH"]: + prompt_path = 'backend/prompts/document_summary_agent_zh.yaml' + else: + prompt_path = 'backend/prompts/document_summary_agent.yaml' + + with open(prompt_path, 'r', encoding='utf-8') as f: + prompts = yaml.safe_load(f) + + system_prompt = prompts.get('system_prompt', '') + user_prompt_template = prompts.get('user_prompt', '') + + user_prompt = Template(user_prompt_template, undefined=StrictUndefined).render( + filename=filename, + content=document_content, + max_words=max_words + ) + + logger.info(f"Document summary prompt generated for {filename} (max_words: {max_words})") + + # Call LLM if model_id and tenant_id are provided + if model_id and tenant_id: + from smolagents import OpenAIServerModel + from database.model_management_db import get_model_by_model_id + from utils.config_utils import get_model_name_from_config + from consts.const import MESSAGE_ROLE + + # Get model configuration + llm_model_config = get_model_by_model_id(model_id=model_id, tenant_id=tenant_id) + if not llm_model_config: + logger.warning(f"No model configuration found for model_id: {model_id}, tenant_id: {tenant_id}") + return f"[Document Summary: {filename}] (max {max_words} words) - Content: {document_content[:200]}..." + + # Create LLM instance + llm = OpenAIServerModel( + model_id=get_model_name_from_config(llm_model_config) if llm_model_config else "", + api_base=llm_model_config.get("base_url", ""), + api_key=llm_model_config.get("api_key", ""), + temperature=0.3, + top_p=0.95 + ) + + # Build messages + messages = [ + {"role": MESSAGE_ROLE["SYSTEM"], "content": system_prompt}, + {"role": MESSAGE_ROLE["USER"], "content": user_prompt} + ] + + # Call LLM + response = llm(messages, max_tokens=max_words * 2) # Allow more tokens for generation + return response.content.strip() + else: + # Fallback to placeholder if no model configuration + logger.warning("No model_id or tenant_id provided, using placeholder summary") + return f"[Document Summary: {filename}] (max {max_words} words) - Content: {document_content[:200]}..." 
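# Map-stage usage sketch (every argument value below is a hypothetical placeholder, not a
# value taken from this change):
#   doc_summary = summarize_document(
#       document_content="[Chunk 1]\nQ3 revenue grew 12% year over year...",
#       filename="q3_report.pdf",
#       language="en",
#       max_words=100,
#       model_id=7,
#       tenant_id="tenant-demo",
#   )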
+ + except Exception as e: + logger.error(f"Error generating document summary: {str(e)}", exc_info=True) + return f"Failed to generate summary for {filename}: {str(e)}" + + +def summarize_cluster(document_summaries: List[str], language: str = LANGUAGE["ZH"], max_words: int = 150, model_id: Optional[int] = None, tenant_id: Optional[str] = None) -> str: + """ + Summarize a cluster of documents using LLM (Reduce stage) + + Args: + document_summaries: List of individual document summaries + language: Language code ('zh' or 'en') + max_words: Maximum words in the summary + model_id: Model ID for LLM call + tenant_id: Tenant ID for model configuration + + Returns: + Cluster summary text + """ + try: + # Select prompt file based on language + if language == LANGUAGE["ZH"]: + prompt_path = 'backend/prompts/cluster_summary_reduce_zh.yaml' + else: + prompt_path = 'backend/prompts/cluster_summary_reduce.yaml' + + with open(prompt_path, 'r', encoding='utf-8') as f: + prompts = yaml.safe_load(f) + + system_prompt = prompts.get('system_prompt', '') + user_prompt_template = prompts.get('user_prompt', '') + + # Format document summaries + summaries_text = "\n\n".join([f"Document {i+1}: {summary}" for i, summary in enumerate(document_summaries)]) + + user_prompt = Template(user_prompt_template, undefined=StrictUndefined).render( + document_summaries=summaries_text, + max_words=max_words + ) + + logger.info(f"Cluster summary prompt generated (language: {language}, max_words: {max_words})") + + # Call LLM if model_id and tenant_id are provided + if model_id and tenant_id: + from smolagents import OpenAIServerModel + from database.model_management_db import get_model_by_model_id + from utils.config_utils import get_model_name_from_config + from consts.const import MESSAGE_ROLE + + # Get model configuration + llm_model_config = get_model_by_model_id(model_id=model_id, tenant_id=tenant_id) + if not llm_model_config: + logger.warning(f"No model configuration found for model_id: {model_id}, tenant_id: {tenant_id}") + return f"[Cluster Summary] (max {max_words} words) - Based on {len(document_summaries)} documents" + + # Create LLM instance + llm = OpenAIServerModel( + model_id=get_model_name_from_config(llm_model_config) if llm_model_config else "", + api_base=llm_model_config.get("base_url", ""), + api_key=llm_model_config.get("api_key", ""), + temperature=0.3, + top_p=0.95 + ) + + # Build messages + messages = [ + {"role": MESSAGE_ROLE["SYSTEM"], "content": system_prompt}, + {"role": MESSAGE_ROLE["USER"], "content": user_prompt} + ] + + # Call LLM + response = llm(messages, max_tokens=max_words * 2) # Allow more tokens for generation + return response.content.strip() + else: + # Fallback to placeholder if no model configuration + logger.warning("No model_id or tenant_id provided, using placeholder summary") + return f"[Cluster Summary] (max {max_words} words) - Based on {len(document_summaries)} documents" + + except Exception as e: + logger.error(f"Error generating cluster summary: {str(e)}", exc_info=True) + return f"Failed to generate summary: {str(e)}" + + +def extract_representative_chunks_smart(chunks: List[Dict], max_chunks: int = 3) -> List[Dict]: + """ + Intelligently extract representative chunks from a document + + Strategy: + 1. Always include first chunk (usually contains title/abstract) + 2. Extract chunks with highest keyword density (important content) + 3. 
Include last chunk if significant (may contain conclusions) + + Args: + chunks: List of chunk dictionaries with 'content' field + max_chunks: Maximum number of chunks to return + + Returns: + List of representative chunks + """ + if len(chunks) <= max_chunks: + return chunks + + selected_chunks = [] + + # 1. Always include first chunk + selected_chunks.append(chunks[0]) + + # 2. Find chunks with high keyword density + try: + from nexent.core.nlp.tokenizer import calculate_term_weights + except ImportError: + # Fallback: use simple scoring + logger.warning("Could not import calculate_term_weights, using simple scoring") + # Simple fallback: just pick middle chunks + if len(chunks) > 1: + selected_chunks.append(chunks[len(chunks)//2]) + if len(selected_chunks) < max_chunks and len(chunks) > 2: + selected_chunks.append(chunks[-1]) + return selected_chunks[:max_chunks] + + chunk_scores = [] + for i, chunk in enumerate(chunks[1:-1]): # Skip first and last + content = chunk.get('content', '') + if len(content) > 500: + # Calculate keyword density (use first 500 chars for speed) + keywords = calculate_term_weights(content[:500]) + score = len(keywords) * 0.5 + len(content) * 0.001 # Balance keyword count and length + chunk_scores.append((i + 1, score, chunk)) + + # Sort by score and pick top chunks + chunk_scores.sort(key=lambda x: x[1], reverse=True) + remaining_slots = max_chunks - 1 # Already have first chunk + + for idx, score, chunk in chunk_scores[:remaining_slots]: + selected_chunks.append(chunk) + + # 3. If we have space, include last chunk + if len(selected_chunks) < max_chunks and len(chunks) > 1: + selected_chunks.append(chunks[-1]) + + return selected_chunks[:max_chunks] + + +def merge_cluster_summaries(cluster_summaries: Dict[int, str]) -> str: + """ + Merge all cluster summaries into a final knowledge base summary + + Args: + cluster_summaries: Dictionary mapping cluster_id to cluster summary + + Returns: + Final merged knowledge base summary + """ + if not cluster_summaries: + return "" + + # Sort by cluster ID for consistent output + sorted_clusters = sorted(cluster_summaries.items()) + + # Format cluster summaries with HTML paragraph tags for explicit rendering + summary_parts = [] + for _, summary in sorted_clusters: + if summary.strip(): + # Wrap each summary in

<p> tags for explicit paragraph rendering + summary_parts.append(f"<p>{summary.strip()}</p>") + + # Join with simple double newlines, as <p>
tags already handle block-level separation + final_summary = "\n\n".join(summary_parts) + + logger.info(f"Merged {len(cluster_summaries)} cluster summaries into final knowledge base summary") + return final_summary + + +def analyze_cluster_coherence(cluster_doc_ids: List[str], document_samples: Dict[str, Dict]) -> Dict[str, any]: + """ + Analyze coherence and structure of documents within a cluster + + Returns: + Dict with analysis results including common themes, document types, etc. + """ + if not cluster_doc_ids: + return {} + + # Extract document titles and content previews + doc_previews = [] + for doc_id in cluster_doc_ids: + if doc_id in document_samples: + doc_info = document_samples[doc_id] + filename = doc_info.get('filename', 'unknown') + chunks = doc_info.get('chunks', []) + if chunks: + first_chunk = chunks[0].get('content', '')[:200] + doc_previews.append({'filename': filename, 'preview': first_chunk}) + + return { + 'doc_count': len(cluster_doc_ids), + 'doc_previews': doc_previews, + 'file_types': [doc['filename'].split('.')[-1] for doc in doc_previews if '.' in doc['filename']] + } + + +def summarize_clusters_map_reduce(document_samples: Dict[str, Dict], clusters: Dict[int, List[str]], + language: str = LANGUAGE["ZH"], doc_max_words: int = 100, cluster_max_words: int = 150, + use_smart_chunk_selection: bool = True, enhance_with_metadata: bool = True, + model_id: Optional[int] = None, tenant_id: Optional[str] = None) -> Dict[int, str]: + """ + Summarize all clusters using Map-Reduce approach + + Map stage: Summarize each document individually (within each cluster) + Reduce stage: Combine document summaries within the same cluster into a cluster summary + Note: Clusters remain separate - we combine document summaries WITHIN each cluster + + Args: + document_samples: Dictionary mapping doc_id to document info + clusters: Dictionary mapping cluster_id to list of doc_ids + language: Language code ('zh' or 'en') + doc_max_words: Maximum words per document summary + cluster_max_words: Maximum words per cluster summary + use_smart_chunk_selection: Use intelligent chunk selection based on keyword density + enhance_with_metadata: Enhance summaries with document metadata + model_id: Model ID for LLM calls + tenant_id: Tenant ID for model configuration + + Returns: + Dictionary mapping cluster_id to summary text + """ + cluster_summaries = {} + + for cluster_id, doc_ids in clusters.items(): + logger.info(f"Summarizing cluster {cluster_id} with {len(doc_ids)} documents using Map-Reduce") + + # Map stage: Summarize each document + document_summaries = [] + for doc_id in doc_ids: + if doc_id not in document_samples: + continue + + doc_info = document_samples[doc_id] + chunks = doc_info.get('chunks', []) + filename = doc_info.get('filename', 'unknown') + + # Extract representative content for this document + if use_smart_chunk_selection: + representative_chunks = extract_representative_chunks_smart(chunks, max_chunks=3) + else: + # Simple approach: first, middle, last + if len(chunks) <= 3: + representative_chunks = chunks + else: + representative_chunks = ( + chunks[:1] + + chunks[len(chunks)//2:len(chunks)//2+1] + + chunks[-1:] + ) + + # Format document content (merge top-K chunks) + doc_content = "" + for i, chunk in enumerate(representative_chunks): + content = chunk.get('content', '') + # Limit each chunk length for individual document + if len(content) > 1000: + content = content[:1000] + "..." 
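# Note: with at most three representative chunks per document and this 1000-character cap,
# each Map-stage prompt carries roughly 3,000 characters of chunk text; the cap is a tuning
# choice made here rather than a limit imposed elsewhere in this change.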
+ # Add chunk separator + doc_content += f"[Chunk {i+1}]\n{content}\n\n" + + # Generate document summary from merged chunks + logger.info(f"Summarizing document {filename} with {len(representative_chunks)} representative chunks") + doc_summary = summarize_document(doc_content, filename, language, doc_max_words, model_id, tenant_id) + document_summaries.append(doc_summary) + + # Reduce stage: Combine document summaries within this cluster into cluster summary + if document_summaries: + # Optionally enhance with cluster analysis + if enhance_with_metadata: + cluster_analysis = analyze_cluster_coherence(doc_ids, document_samples) + logger.info(f"Cluster {cluster_id} analysis: {cluster_analysis.get('doc_count', 0)} documents") + + cluster_summary = summarize_cluster(document_summaries, language, cluster_max_words, model_id, tenant_id) + cluster_summaries[cluster_id] = cluster_summary + else: + logger.warning(f"No valid documents found in cluster {cluster_id}") + cluster_summaries[cluster_id] = "No content available for this cluster" + + return cluster_summaries + + +def summarize_clusters(document_samples: Dict[str, Dict], clusters: Dict[int, List[str]], + language: str = LANGUAGE["ZH"], max_words: int = 150) -> Dict[int, str]: + """ + Summarize all clusters (legacy method - kept for backward compatibility) + + Note: This method uses the old approach. Use summarize_clusters_map_reduce for better results. + + Args: + document_samples: Dictionary mapping doc_id to document info + clusters: Dictionary mapping cluster_id to list of doc_ids + language: Language code ('zh' or 'en') + max_words: Maximum words per cluster summary + + Returns: + Dictionary mapping cluster_id to summary text + """ + cluster_summaries = {} + + for cluster_id, doc_ids in clusters.items(): + logger.info(f"Summarizing cluster {cluster_id} with {len(doc_ids)} documents") + + # Extract cluster content + cluster_content = extract_cluster_content(document_samples, doc_ids, max_chunks_per_doc=3) + + # Generate summary using old method + summary = summarize_cluster_legacy(cluster_content, language, max_words) + cluster_summaries[cluster_id] = summary + + return cluster_summaries + + +def summarize_cluster_legacy(cluster_content: str, language: str = LANGUAGE["ZH"], max_words: int = 150) -> str: + """ + Legacy cluster summarization method (single-stage) + + Args: + cluster_content: Formatted content from the cluster + language: Language code ('zh' or 'en') + max_words: Maximum words in the summary + + Returns: + Cluster summary text + """ + try: + prompt_path = 'backend/prompts/cluster_summary_agent.yaml' + with open(prompt_path, 'r', encoding='utf-8') as f: + prompts = yaml.safe_load(f) + + system_prompt = prompts.get('system_prompt', '') + user_prompt_template = prompts.get('user_prompt', '') + + user_prompt = Template(user_prompt_template, undefined=StrictUndefined).render( + cluster_content=cluster_content, + max_words=max_words + ) + + logger.info(f"Cluster summary prompt generated (language: {language}, max_words: {max_words})") + + # Note: This is a legacy function, using placeholder summary + # The main summarization uses summarize_cluster() with LLM integration + return f"[Cluster Summary] (max {max_words} words) - Content preview: {cluster_content[:200]}..." 
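# Reduce-stage usage sketch for the non-legacy path above (summarize_cluster); the summaries,
# model_id and tenant_id are hypothetical placeholders, not values from this change:
#   cluster_text = summarize_cluster(
#       ["Covers the Q3 sales pipeline and regional forecasts.",
#        "Breaks down quarterly revenue by product line."],
#       language="en", max_words=150, model_id=7, tenant_id="tenant-demo")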
+ + except Exception as e: + logger.error(f"Error generating cluster summary: {str(e)}", exc_info=True) + return f"Failed to generate summary: {str(e)}" + diff --git a/docker/docker-compose.dev.yml b/docker/docker-compose.dev.yml index 182ae806..cfb20f6e 100644 --- a/docker/docker-compose.dev.yml +++ b/docker/docker-compose.dev.yml @@ -1,38 +1,38 @@ name: nexent services: - nexent: - image: nexent/nexent:latest - container_name: nexent - restart: always - ports: - - "5010:5010" - - "5013:5013" - volumes: - - ../:/opt/ - - /opt/backend/.venv/ - - ${NEXENT_USER_DIR:-$HOME/nexent}:/mnt/nexent - environment: - skip_proxy: "true" - UMASK: 0022 - env_file: - - .env - user: root - logging: - driver: "json-file" - options: - max-size: "10m" - max-file: "3" - networks: - - nexent - entrypoint: "/bin/bash" - command: - - -c - - | - rm -rf /var/lib/apt/lists/* && - echo "Python environment activated: $(which python)" && - echo "Python version: $(python --version)" && - tail -f /dev/null +# nexent: +# image: nexent/nexent:latest +# container_name: nexent +# restart: always +# ports: +# - "5010:5010" +# - "5013:5013" +# volumes: +# - ../:/opt/ +# - /opt/backend/.venv/ +# - ${NEXENT_USER_DIR:-$HOME/nexent}:/mnt/nexent +# environment: +# skip_proxy: "true" +# UMASK: 0022 +# env_file: +# - .env +# user: root +# logging: +# driver: "json-file" +# options: +# max-size: "10m" +# max-file: "3" +# networks: +# - nexent +# entrypoint: "/bin/bash" +# command: +# - -c +# - | +# rm -rf /var/lib/apt/lists/* && +# echo "Python environment activated: $(which python)" && +# echo "Python version: $(python --version)" && +# tail -f /dev/null nexent-data-process: @@ -45,7 +45,7 @@ services: volumes: - ../:/opt/:cached - /opt/backend/.venv/ - - ${NEXENT_USER_DIR:-$HOME/nexent}:/mnt/nexent + - ${ROOT_DIR}:/mnt/nexent-data environment: skip_proxy: "true" PATH: "/usr/local/bin:/usr/bin/:/opt/backend/.venv/bin:${PATH}" @@ -70,27 +70,27 @@ services: echo "Python version: $(python --version)" && python -c "import time; time.sleep(2147483647)" - nexent-web: - image: nexent/nexent-web:latest - container_name: nexent-web - restart: always - networks: - - nexent - ports: - - "3000:3000" - volumes: - - ../frontend:/opt/frontend:cached - - ../frontend/node_modules:/opt/frontend/node_modules:cached - environment: - - HTTP_BACKEND=http://nexent:5010 - - WS_BACKEND=ws://nexent:5010 - - MINIO_ENDPOINT=${MINIO_ENDPOINT} - logging: - driver: "json-file" - options: - max-size: "10m" - max-file: "3" - command: ["/bin/sh", "-c", "echo 'Web Service needs to be started manually. Use\nnpm install -g pnpm\npnpm install\npnpm dev\n under /opt/frontend to start.' && tail -f /dev/null"] +# nexent-web: +# image: nexent/nexent-web:latest +# container_name: nexent-web +# restart: always +# networks: +# - nexent +# ports: +# - "3000:3000" +# volumes: +# - ../frontend:/opt/frontend:cached +# - ../frontend/node_modules:/opt/frontend/node_modules:cached +# environment: +# - HTTP_BACKEND=http://nexent:5010 +# - WS_BACKEND=ws://nexent:5010 +# - MINIO_ENDPOINT=${MINIO_ENDPOINT} +# logging: +# driver: "json-file" +# options: +# max-size: "10m" +# max-file: "3" +# command: ["/bin/sh", "-c", "echo 'Web Service needs to be started manually. Use\nnpm install -g pnpm\npnpm install\npnpm dev\n under /opt/frontend to start.' 
&& tail -f /dev/null"] networks: diff --git a/docker/init.sql b/docker/init.sql index d23e1c7f..4d19084e 100644 --- a/docker/init.sql +++ b/docker/init.sql @@ -290,6 +290,8 @@ CREATE TABLE IF NOT EXISTS nexent.ag_tenant_agent_t ( business_description VARCHAR, model_name VARCHAR(100), model_id INTEGER, + business_logic_model_name VARCHAR(100), + business_logic_model_id INTEGER, max_steps INTEGER, duty_prompt TEXT, constraint_prompt TEXT, @@ -330,6 +332,8 @@ COMMENT ON COLUMN nexent.ag_tenant_agent_t.description IS 'Description'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.business_description IS 'Manually entered by the user to describe the entire business process'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.model_name IS '[DEPRECATED] Name of the model used, use model_id instead'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.model_id IS 'Model ID, foreign key reference to model_record_t.model_id'; +COMMENT ON COLUMN nexent.ag_tenant_agent_t.business_logic_model_name IS 'Model name used for business logic prompt generation'; +COMMENT ON COLUMN nexent.ag_tenant_agent_t.business_logic_model_id IS 'Model ID used for business logic prompt generation, foreign key reference to model_record_t.model_id'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.max_steps IS 'Maximum number of steps'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.duty_prompt IS 'Duty prompt'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.constraint_prompt IS 'Constraint prompt'; @@ -344,8 +348,6 @@ COMMENT ON COLUMN nexent.ag_tenant_agent_t.created_by IS 'Creator'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.updated_by IS 'Updater'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.delete_flag IS 'Whether it is deleted. Optional values: Y/N'; --- Add comments to the columns -COMMENT ON COLUMN nexent.ag_tenant_agent_t.provide_run_summary IS 'Whether to provide the running summary to the manager agent'; -- Create the ag_tool_instance_t table in the nexent schema CREATE TABLE IF NOT EXISTS nexent.ag_tool_instance_t ( @@ -644,4 +646,4 @@ $$ LANGUAGE plpgsql; CREATE TRIGGER "update_partner_mapping_update_time_trigger" BEFORE UPDATE ON "nexent"."partner_mapping_id_t" FOR EACH ROW -EXECUTE FUNCTION "update_partner_mapping_update_time"(); \ No newline at end of file +EXECUTE FUNCTION "update_partner_mapping_update_time"(); diff --git a/docker/sql/1024_add_business_logic_model_fields.sql b/docker/sql/1024_add_business_logic_model_fields.sql new file mode 100644 index 00000000..ff1a7673 --- /dev/null +++ b/docker/sql/1024_add_business_logic_model_fields.sql @@ -0,0 +1,12 @@ +-- Add business_logic_model_name and business_logic_model_id fields to ag_tenant_agent_t table +-- These fields store the LLM model used for generating business logic prompts + +ALTER TABLE nexent.ag_tenant_agent_t +ADD COLUMN IF NOT EXISTS business_logic_model_name VARCHAR(100); + +ALTER TABLE nexent.ag_tenant_agent_t +ADD COLUMN IF NOT EXISTS business_logic_model_id INTEGER; + +COMMENT ON COLUMN nexent.ag_tenant_agent_t.business_logic_model_name IS 'Model name used for business logic prompt generation'; +COMMENT ON COLUMN nexent.ag_tenant_agent_t.business_logic_model_id IS 'Model ID used for business logic prompt generation, foreign key reference to model_record_t.model_id'; + diff --git a/frontend/app/[locale]/chat/streaming/chatStreamMain.tsx b/frontend/app/[locale]/chat/streaming/chatStreamMain.tsx index 6f4d77e9..dde39fed 100644 --- a/frontend/app/[locale]/chat/streaming/chatStreamMain.tsx +++ b/frontend/app/[locale]/chat/streaming/chatStreamMain.tsx @@ -398,6 +398,46 @@ export 
function ChatStreamMain({ shouldScrollToBottom, ]); + // Additional scroll trigger for async content like Mermaid diagrams + useEffect(() => { + if (processedMessages.finalMessages.length > 0 && autoScroll) { + const scrollAreaElement = scrollAreaRef.current?.querySelector( + "[data-radix-scroll-area-viewport]" + ); + if (!scrollAreaElement) return; + + // Use ResizeObserver to detect when content height changes (e.g., Mermaid diagrams finish rendering) + const resizeObserver = new ResizeObserver(() => { + const { scrollTop, scrollHeight, clientHeight } = + scrollAreaElement as HTMLElement; + const distanceToBottom = scrollHeight - scrollTop - clientHeight; + + // Auto-scroll if user is near bottom and content height changed + if (distanceToBottom < 100) { + scrollToBottom(); + } + }); + + resizeObserver.observe(scrollAreaElement); + + // Also use a timeout as fallback for async content + const timeoutId = setTimeout(() => { + const { scrollTop, scrollHeight, clientHeight } = + scrollAreaElement as HTMLElement; + const distanceToBottom = scrollHeight - scrollTop - clientHeight; + + if (distanceToBottom < 100) { + scrollToBottom(); + } + }, 1000); // Wait 1 second for async content to render + + return () => { + resizeObserver.disconnect(); + clearTimeout(timeoutId); + }; + } + }, [processedMessages.finalMessages.length, autoScroll]); + // Scroll to bottom when task messages are updated useEffect(() => { if (autoScroll) { diff --git a/frontend/app/[locale]/chat/streaming/taskWindow.tsx b/frontend/app/[locale]/chat/streaming/taskWindow.tsx index c8bd3f61..4067c992 100644 --- a/frontend/app/[locale]/chat/streaming/taskWindow.tsx +++ b/frontend/app/[locale]/chat/streaming/taskWindow.tsx @@ -664,6 +664,7 @@ const messageHandlers: MessageHandler[] = [ ), @@ -757,6 +758,7 @@ const messageHandlers: MessageHandler[] = [ ); } else { @@ -1061,7 +1063,8 @@ export function TaskWindow({ messages, isStreaming = false }: TaskWindowProps) { const maxHeight = 300; const headerHeight = 55; const availableHeight = maxHeight - headerHeight; - const actualContentHeight = Math.min(contentHeight + 16, availableHeight); + // Add extra padding for diagrams to prevent bottom cutoff + const actualContentHeight = Math.min(contentHeight + 32, availableHeight); const containerHeight = isExpanded ? headerHeight + actualContentHeight : "auto"; @@ -1096,15 +1099,15 @@ export function TaskWindow({ messages, isStreaming = false }: TaskWindowProps) { {isExpanded && ( -

+
{needsScroll ? ( -
+
{renderMessages()}
) : ( -
+
{renderMessages()}
)} @@ -1183,6 +1186,36 @@ export function TaskWindow({ messages, isStreaming = false }: TaskWindowProps) { box-sizing: border-box !important; } + /* Override diagram size in task window */ + .task-message-content .my-4 { + max-width: 200px !important; + margin: 0 auto !important; + display: flex !important; + justify-content: center !important; + } + + .task-message-content .my-4 img { + max-width: 200px !important; + width: 200px !important; + margin: 0 auto !important; + display: block !important; + } + + /* More specific selectors for mermaid diagrams */ + .task-message-content .task-message-content .my-4 { + max-width: 200px !important; + margin: 0 auto !important; + display: flex !important; + justify-content: center !important; + } + + .task-message-content .task-message-content .my-4 img { + max-width: 200px !important; + width: 200px !important; + margin: 0 auto !important; + display: block !important; + } + /* Paragraph spacing adjustment */ .task-message-content p { margin-bottom: 0.5rem !important; diff --git a/frontend/app/[locale]/setup/agents/components/AgentSetupOrchestrator.tsx b/frontend/app/[locale]/setup/agents/components/AgentSetupOrchestrator.tsx index 0a20afb6..64b6bb48 100644 --- a/frontend/app/[locale]/setup/agents/components/AgentSetupOrchestrator.tsx +++ b/frontend/app/[locale]/setup/agents/components/AgentSetupOrchestrator.tsx @@ -42,6 +42,10 @@ export default function AgentSetupOrchestrator({ setMainAgentModelId, mainAgentMaxStep, setMainAgentMaxStep, + businessLogicModel, + setBusinessLogicModel, + businessLogicModelId, + setBusinessLogicModelId, tools, subAgentList = [], loadingAgents = false, @@ -212,8 +216,6 @@ export default function AgentSetupOrchestrator({ if (!isEditingAgent) { // Only clear and get new Agent configuration in creating mode setBusinessLogic(""); - setMainAgentModel(null); // Clear model selection when creating new agent - setMainAgentModelId(null); // Clear model ID when creating new agent fetchSubAgentIdAndEnableToolList(t); } else { // In edit mode, data is loaded in handleEditAgent, here validate the form @@ -323,8 +325,31 @@ export default function AgentSetupOrchestrator({ setIsEditingAgent(false); setEditingAgent(null); setIsCreatingNewAgent(true); - // Note: Don't clear content here - let the parent component's useEffect handle restoration - // The parent component will restore cached content if available + + // Clear all content when creating new agent to avoid showing cached data + setBusinessLogic(""); + setDutyContent?.(""); + setConstraintContent?.(""); + setFewShotsContent?.(""); + setAgentName?.(""); + setAgentDescription?.(""); + setAgentDisplayName?.(""); + + // Clear tool and agent selections + setSelectedTools([]); + setEnabledToolIds([]); + setEnabledAgentIds([]); + + // Clear business logic model to allow default from global settings + // The useEffect in PromptManager will set it to the default from localStorage + setBusinessLogicModel(null); + setBusinessLogicModelId(null); + + // Clear main agent model selection to trigger default model selection + // The useEffect in AgentConfigModal will set it to the default from localStorage + setMainAgentModel(null); + setMainAgentModelId(null); + onEditingStateChange?.(false, null); }; @@ -417,7 +442,9 @@ export default function AgentSetupOrchestrator({ constraintContent, fewShotsContent, agentDisplayName, - mainAgentModelId ?? undefined + mainAgentModelId ?? undefined, + businessLogicModel ?? undefined, + businessLogicModelId ?? 
undefined ); } else { result = await updateAgent( @@ -433,7 +460,9 @@ export default function AgentSetupOrchestrator({ constraintContent, fewShotsContent, agentDisplayName, - mainAgentModelId ?? undefined + mainAgentModelId ?? undefined, + businessLogicModel ?? undefined, + businessLogicModelId ?? undefined ); } @@ -555,6 +584,8 @@ export default function AgentSetupOrchestrator({ setMainAgentModelId(agentDetail.model_id); setMainAgentMaxStep(agentDetail.max_step); setBusinessLogic(agentDetail.business_description || ""); + setBusinessLogicModel(agentDetail.business_logic_model_name || null); + setBusinessLogicModelId(agentDetail.business_logic_model_id || null); // Use backend returned sub_agent_id_list to set enabled agent list if ( @@ -595,21 +626,30 @@ export default function AgentSetupOrchestrator({ }; // Handle the update of the model + // Handle Business Logic Model change + const handleBusinessLogicModelChange = (value: string, modelId?: number) => { + setBusinessLogicModel(value); + if (modelId !== undefined) { + setBusinessLogicModelId(modelId); + } + }; + const handleModelChange = async (value: string, modelId?: number) => { const targetAgentId = isEditingAgent && editingAgent ? editingAgent.id : mainAgentId; - if (!targetAgentId) { - message.error(t("businessLogic.config.error.noAgentId")); - return; - } - // Update local state first setMainAgentModel(value); if (modelId !== undefined) { setMainAgentModelId(modelId); } + // If no agent ID yet (e.g., during initial creation setup), just update local state + // The model will be saved when the agent is fully created + if (!targetAgentId) { + return; + } + // Call updateAgent API to save the model change try { const result = await updateAgent( @@ -961,6 +1001,9 @@ export default function AgentSetupOrchestrator({ } onMaxStepChange={handleMaxStepChange} onBusinessLogicChange={(value: string) => setBusinessLogic(value)} + onBusinessLogicModelChange={handleBusinessLogicModelChange} + businessLogicModel={businessLogicModel} + businessLogicModelId={businessLogicModelId} onGenerateAgent={onGenerateAgent || (() => {})} onSaveAgent={handleSaveAgent} isGeneratingAgent={isGeneratingAgent} diff --git a/frontend/app/[locale]/setup/agents/components/PromptManager.tsx b/frontend/app/[locale]/setup/agents/components/PromptManager.tsx index 4bf08fb6..6ee33b55 100644 --- a/frontend/app/[locale]/setup/agents/components/PromptManager.tsx +++ b/frontend/app/[locale]/setup/agents/components/PromptManager.tsx @@ -2,7 +2,7 @@ import { useState, useRef, useEffect } from "react"; import { useTranslation } from "react-i18next"; -import { Modal, Badge, Input, App, Dropdown, Button } from "antd"; +import { Modal, Badge, Input, App, Button, Select } from "antd"; import { ThunderboltOutlined, LoadingOutlined, @@ -189,6 +189,10 @@ export interface PromptManagerProps { mainAgentModelId?: number | null; mainAgentMaxStep?: number; + // Business Logic Model (independent from main agent model) + businessLogicModel?: string | null; + businessLogicModelId?: number | null; + // Edit state isEditingMode?: boolean; isGeneratingAgent?: boolean; @@ -197,6 +201,7 @@ export interface PromptManagerProps { // Callback functions onBusinessLogicChange?: (content: string) => void; + onBusinessLogicModelChange?: (value: string, modelId?: number) => void; onDutyContentChange?: (content: string) => void; onConstraintContentChange?: (content: string) => void; onFewShotsContentChange?: (content: string) => void; @@ -233,11 +238,14 @@ export default function PromptManager({ 
mainAgentModel = "", mainAgentModelId = null, mainAgentMaxStep = 5, + businessLogicModel = null, + businessLogicModelId = null, isEditingMode = false, isGeneratingAgent = false, isCreatingNewAgent = false, canSaveAgent = false, onBusinessLogicChange, + onBusinessLogicModelChange, onDutyContentChange, onConstraintContentChange, onFewShotsContentChange, @@ -255,6 +263,7 @@ export default function PromptManager({ getButtonTitle, editingAgent, onModelSelect, + selectedGenerateModel, }: PromptManagerProps) { const { t } = useTranslation("common"); const { message } = App.useApp(); @@ -266,7 +275,20 @@ export default function PromptManager({ // Model selection states const [availableModels, setAvailableModels] = useState([]); const [loadingModels, setLoadingModels] = useState(false); - const [showModelDropdown, setShowModelDropdown] = useState(false); + // Fallback internal selection when parent does not control selection + const [internalSelectedModel, setInternalSelectedModel] = useState< + ModelOption | null + >(selectedGenerateModel ?? null); + + // Keep internal state in sync when parent-controlled value changes + useEffect(() => { + if (selectedGenerateModel && selectedGenerateModel?.id !== internalSelectedModel?.id) { + setInternalSelectedModel(selectedGenerateModel); + } + if (!selectedGenerateModel && internalSelectedModel) { + // Parent cleared selection; keep internal unless explicitly needed + } + }, [selectedGenerateModel]); // Load available models on component mount useEffect(() => { @@ -286,37 +308,122 @@ export default function PromptManager({ } }; - // Handle model selection and auto-generate - const handleModelSelect = (model: ModelOption) => { - onModelSelect?.(model); - setShowModelDropdown(false); + // Ensure a separate Business Logic LLM default selection using global default on creation + // IMPORTANT: Only read from localStorage when creating a NEW agent, not when editing existing agent + useEffect(() => { + if (!isCreatingNewAgent) return; // Only apply to new agents + if (!availableModels || availableModels.length === 0) return; + if (businessLogicModelId) return; // Already set - // Auto-trigger generation after model selection - if (onGenerateAgent) { - onGenerateAgent(model); + try { + const storedModelConfig = localStorage.getItem("model"); + const parsed = storedModelConfig ? 
JSON.parse(storedModelConfig) : null; + const defaultDisplayName = parsed?.llm?.displayName || ""; + const defaultModelName = parsed?.llm?.modelName || ""; + + let target = null as ModelOption | null; + if (defaultDisplayName) { + target = availableModels.find((m) => m.displayName === defaultDisplayName) || null; + } + if (!target && defaultModelName) { + target = availableModels.find((m) => m.name === defaultModelName) || null; + } + if (!target) { + target = availableModels[0] || null; + } + if (target && onBusinessLogicModelChange) { + onBusinessLogicModelChange(target.displayName, target.id); + } else if (target) { + if (onModelSelect) { + onModelSelect(target); + } else { + setInternalSelectedModel(target); + } + } + } catch (_e) { + // ignore parse errors + } + }, [isCreatingNewAgent, availableModels, businessLogicModelId, onBusinessLogicModelChange, onModelSelect]); + + // When editing an existing agent, load previously selected business logic model + useEffect(() => { + if (isCreatingNewAgent) return; + if (!availableModels || availableModels.length === 0) return; + if (selectedGenerateModel) return; // already set by parent/user + + let target: ModelOption | null = null; + if (businessLogicModelId) { + target = availableModels.find((m) => m.id === businessLogicModelId) || null; + } + if (!target && businessLogicModel) { + target = + availableModels.find((m) => m.displayName === businessLogicModel) || + availableModels.find((m) => m.name === businessLogicModel) || + null; + } + if (target) { + if (onModelSelect) { + onModelSelect(target); + } else { + setInternalSelectedModel(target); + } + } + }, [ + isCreatingNewAgent, + availableModels, + selectedGenerateModel, + businessLogicModelId, + businessLogicModel, + onModelSelect, + ]); + + // Handle model selection for prompt generation + const handleModelSelect = (modelId: number) => { + const model = availableModels.find((m) => m.id === modelId); + if (!model) return; + if (onBusinessLogicModelChange) { + onBusinessLogicModelChange(model.displayName, model.id); + } else if (onModelSelect) { + onModelSelect(model); + } else { + setInternalSelectedModel(model); } }; - // Handle generate button click - show model dropdown + // Handle generate button click const handleGenerateClick = () => { if (availableModels.length === 0) { message.warning(t("businessLogic.config.error.noAvailableModels")); return; } - setShowModelDropdown(true); + // Check if a model is selected: priority order is businessLogicModelId, selectedGenerateModel, internalSelectedModel + let chosen: ModelOption | null = null; + if (businessLogicModelId) { + chosen = availableModels.find((m) => m.id === businessLogicModelId) || null; + } + if (!chosen && selectedGenerateModel) { + chosen = selectedGenerateModel; + } + if (!chosen && internalSelectedModel) { + chosen = internalSelectedModel; + } + + if (!chosen) { + message.warning(t("businessLogic.config.modelPlaceholder")); + return; + } + if (onGenerateAgent) { + onGenerateAgent(chosen); + } }; - // Create dropdown items with disabled state for unavailable models - const modelDropdownItems = availableModels.map((model) => { - const isAvailable = model.connect_status === 'available'; - return { - key: model.id, - label: model.displayName || model.name, - disabled: !isAvailable, - onClick: () => handleModelSelect(model), - }; - }); + // Select options for available models + const modelSelectOptions = availableModels.map((model) => ({ + value: model.id, + label: model.displayName || model.name, + disabled: 
model.connect_status !== "available", + })); // Handle expand edit const handleExpandCard = (index: number) => { @@ -425,35 +532,47 @@ export default function PromptManager({
- {/* Main content */} -
- {/* Business logic description section */} -
-
-

- {t("businessLogic.title")} -

-
-
- onBusinessLogicChange?.(e.target.value)} - placeholder={t("businessLogic.placeholder")} - className="w-full resize-none p-3 text-sm transition-all duration-300 system-prompt-business-logic" - style={{ - minHeight: "120px", - maxHeight: "200px", - paddingRight: "12px", - paddingBottom: "40px", // Reserve space for button - }} - autoSize={{ - minRows: 3, - maxRows: 5, - }} - disabled={!isEditingMode} - /> - {/* Generate button */} -
+ {/* Main content */} +
+ {/* Business logic description section */} +
+
+

+ {t("businessLogic.title")} +

+
+
+ onBusinessLogicChange?.(e.target.value)} + placeholder={t("businessLogic.placeholder")} + className="w-full resize-none p-3 text-sm transition-all duration-300 system-prompt-business-logic" + style={{ + minHeight: "120px", + maxHeight: "200px", + paddingRight: "12px", + paddingBottom: "40px", // Reserve space for button + }} + autoSize={{ + minRows: 3, + maxRows: 5, + }} + disabled={!isEditingMode} + /> + {/* Generate button */} +
+
+ {t("businessLogic.config.model")}: + { const modelId = option && 'key' in option ? Number(option.key) : undefined; + setLocalMainAgentModel(value); onModelChange?.(value, modelId); }} size="large" - disabled={!isEditingMode} + disabled={false} style={{ width: "100%" }} placeholder={t("businessLogic.config.modelPlaceholder")} > diff --git a/frontend/app/[locale]/setup/agents/components/tool/ToolConfigModal.tsx b/frontend/app/[locale]/setup/agents/components/tool/ToolConfigModal.tsx index 0c9d93cd..eb28901c 100644 --- a/frontend/app/[locale]/setup/agents/components/tool/ToolConfigModal.tsx +++ b/frontend/app/[locale]/setup/agents/components/tool/ToolConfigModal.tsx @@ -14,7 +14,11 @@ import { Typography, Tooltip, } from "antd"; -import { CloseOutlined } from "@ant-design/icons"; +import { + CloseOutlined, + SettingOutlined, + EditOutlined, +} from "@ant-design/icons"; import { TOOL_PARAM_TYPES } from "@/const/agentConfig"; import { ToolParam, ToolConfigModalProps } from "@/types/agentConfig"; @@ -30,6 +34,7 @@ import { } from "@/services/agentConfigService"; import log from "@/lib/logger"; import { useModalPosition } from "@/hooks/useModalPosition"; +import { DEFAULT_TYPE } from "@/const/constants"; export default function ToolConfigModal({ isOpen, @@ -52,6 +57,9 @@ export default function ToolConfigModal({ const [parsedInputs, setParsedInputs] = useState>({}); const [paramValues, setParamValues] = useState>({}); const [dynamicInputParams, setDynamicInputParams] = useState([]); + const [isManualInputMode, setIsManualInputMode] = useState(false); + const [manualJsonInput, setManualJsonInput] = useState(""); + const [isParseSuccessful, setIsParseSuccessful] = useState(false); const { windowWidth, mainModalTop, mainModalRight } = useModalPosition(isOpen); @@ -226,45 +234,64 @@ export default function ToolConfigModal({ const parsedInputs = parseToolInputs(tool.inputs || ""); const paramNames = extractParameterNames(parsedInputs); - setParsedInputs(parsedInputs); - setDynamicInputParams(paramNames); - - // Initialize parameter values with appropriate defaults based on type - const initialValues: Record = {}; - paramNames.forEach((paramName) => { - const paramInfo = parsedInputs[paramName]; - const paramType = paramInfo?.type || "string"; - - if ( - paramInfo && - typeof paramInfo === "object" && - paramInfo.default != null - ) { - // Use provided default value, convert to string for UI display - switch (paramType) { - case "boolean": - initialValues[paramName] = paramInfo.default ? 
"true" : "false"; - break; - case "array": - case "object": - // JSON.stringify with indentation of 2 spaces for better readability - initialValues[paramName] = JSON.stringify( - paramInfo.default, - null, - 2 - ); - break; - default: - initialValues[paramName] = String(paramInfo.default); + // Check if parsing was successful (not empty object) + const isSuccessful = Object.keys(parsedInputs).length > 0; + setIsParseSuccessful(isSuccessful); + if (isSuccessful) { + setParsedInputs(parsedInputs); + setDynamicInputParams(paramNames); + + // Initialize parameter values with appropriate defaults based on type + const initialValues: Record = {}; + paramNames.forEach((paramName) => { + const paramInfo = parsedInputs[paramName]; + const paramType = paramInfo?.type || DEFAULT_TYPE; + + if ( + paramInfo && + typeof paramInfo === "object" && + paramInfo.default != null + ) { + // Use provided default value, convert to string for UI display + switch (paramType) { + case "boolean": + initialValues[paramName] = paramInfo.default ? "true" : "false"; + break; + case "array": + case "object": + // JSON.stringify with indentation of 2 spaces for better readability + initialValues[paramName] = JSON.stringify( + paramInfo.default, + null, + 2 + ); + break; + default: + initialValues[paramName] = String(paramInfo.default); + } } - } - }); - setParamValues(initialValues); + }); + setParamValues(initialValues); + // Reset to parsed mode when parsing succeeds + setIsManualInputMode(false); + setManualJsonInput(""); + } else { + // Parsing returned empty object, treat as failed + setParsedInputs({}); + setParamValues({}); + setDynamicInputParams([]); + setIsManualInputMode(true); + setManualJsonInput("{}"); + } } catch (error) { log.error("Parameter parsing error:", error); setParsedInputs({}); setParamValues({}); setDynamicInputParams([]); + setIsParseSuccessful(false); + // When parsing fails, automatically switch to manual input mode + setIsManualInputMode(true); + setManualJsonInput("{}"); } setTestPanelVisible(true); @@ -278,6 +305,9 @@ export default function ToolConfigModal({ setParamValues({}); setDynamicInputParams([]); setTestExecuting(false); + setIsManualInputMode(false); + setManualJsonInput(""); + setIsParseSuccessful(false); }; // Execute tool test @@ -292,7 +322,7 @@ export default function ToolConfigModal({ dynamicInputParams.forEach((paramName) => { const value = paramValues[paramName]; const paramInfo = parsedInputs[paramName]; - const paramType = paramInfo?.type || "string"; + const paramType = paramInfo?.type || DEFAULT_TYPE; if (value && value.trim() !== "") { // Convert value to correct type based on parameter type from inputs @@ -674,60 +704,144 @@ export default function ToolConfigModal({ )} - {/* Dynamic input parameters from tool inputs */} - {dynamicInputParams.length > 0 && ( + {/* Input parameters section with conditional toggle */} + {(dynamicInputParams.length > 0 || isManualInputMode) && ( <> - - {t("toolConfig.toolTest.inputParams")} -
- {dynamicInputParams.map((paramName) => { - const paramInfo = parsedInputs[paramName]; - const description = - paramInfo && - typeof paramInfo === "object" && - paramInfo.description - ? paramInfo.description - : paramName; - - return ( -
- {paramName} - - { - setParamValues((prev) => ({ - ...prev, - [paramName]: e.target.value, - })); - }} - style={{ flex: 1 }} - /> - -
- ); - })} + {t("toolConfig.toolTest.inputParams")} + {/* Only show toggle button if parsing was successful */} + {isParseSuccessful && ( + + )}
+ + {isManualInputMode ? ( + // Manual JSON input mode +
+ setManualJsonInput(e.target.value)} + rows={6} + style={{ fontFamily: "monospace" }} + /> +
+ ) : ( + // Parsed parameters mode + dynamicInputParams.length > 0 && ( +
+ {dynamicInputParams.map((paramName) => { + const paramInfo = parsedInputs[paramName]; + const description = + paramInfo && + typeof paramInfo === "object" && + paramInfo.description + ? paramInfo.description + : paramName; + + return ( +
+ + {paramName} + + + { + setParamValues((prev) => ({ + ...prev, + [paramName]: e.target.value, + })); + }} + style={{ flex: 1 }} + /> + +
+ ); + })} +
+ ) + )} )} diff --git a/frontend/app/[locale]/setup/agents/config.tsx b/frontend/app/[locale]/setup/agents/config.tsx index 6286cd32..bd56c685 100644 --- a/frontend/app/[locale]/setup/agents/config.tsx +++ b/frontend/app/[locale]/setup/agents/config.tsx @@ -49,6 +49,8 @@ export default function AgentConfig() { const [mainAgentModel, setMainAgentModel] = useState(null); const [mainAgentModelId, setMainAgentModelId] = useState(null); const [mainAgentMaxStep, setMainAgentMaxStep] = useState(5); + const [businessLogicModel, setBusinessLogicModel] = useState(null); + const [businessLogicModelId, setBusinessLogicModelId] = useState(null); const [tools, setTools] = useState([]); const [mainAgentId, setMainAgentId] = useState(null); const [subAgentList, setSubAgentList] = useState([]); @@ -439,6 +441,10 @@ export default function AgentConfig() { setMainAgentModelId={setMainAgentModelId} mainAgentMaxStep={mainAgentMaxStep} setMainAgentMaxStep={setMainAgentMaxStep} + businessLogicModel={businessLogicModel} + setBusinessLogicModel={setBusinessLogicModel} + businessLogicModelId={businessLogicModelId} + setBusinessLogicModelId={setBusinessLogicModelId} tools={tools} subAgentList={subAgentList} loadingAgents={loadingAgents} diff --git a/frontend/app/[locale]/setup/knowledges/components/document/DocumentList.tsx b/frontend/app/[locale]/setup/knowledges/components/document/DocumentList.tsx index 8296c815..f79616e5 100644 --- a/frontend/app/[locale]/setup/knowledges/components/document/DocumentList.tsx +++ b/frontend/app/[locale]/setup/knowledges/components/document/DocumentList.tsx @@ -9,6 +9,7 @@ import { useTranslation } from "react-i18next"; import { Input, Button, App, Select } from "antd"; import { InfoCircleFilled } from "@ant-design/icons"; +import { MarkdownRenderer } from "@/components/ui/markdownRenderer"; import { UI_CONFIG, @@ -23,6 +24,7 @@ import { Document } from "@/types/knowledgeBase"; import { ModelOption } from "@/types/modelConfig"; import { formatFileSize, sortByStatusAndDate } from "@/lib/utils"; import log from "@/lib/logger"; +import { useConfig } from "@/hooks/useConfig"; import DocumentStatus from "./DocumentStatus"; import UploadArea from "../upload/UploadArea"; @@ -85,6 +87,7 @@ const DocumentListContainer = forwardRef( const { message } = App.useApp(); const uploadAreaRef = useRef(null); const { state: docState } = useDocumentContext(); + const { modelConfig } = useConfig(); // Use fixed height instead of percentage const titleBarHeight = UI_CONFIG.TITLE_BAR_HEIGHT; @@ -148,9 +151,57 @@ const DocumentListContainer = forwardRef( try { const models = await modelService.getLLMModels(); setAvailableModels(models); - // Set first available model as default - if (models.length > 0) { - setSelectedModel(models[0].id); + + // Determine initial selection order: + // 1) Knowledge base's own configured model (server-side config) + // 2) Globally configured default LLM from quick setup (create mode or no KB model) + // 3) First available model + + let initialModelId: number | null = null; + + // 1) Knowledge base model (if provided) + if (knowledgeBaseModel) { + const matchedByName = models.find((m) => m.name === knowledgeBaseModel); + const matchedByDisplay = matchedByName + ? 
null + : models.find((m) => m.displayName === knowledgeBaseModel); + if (matchedByName) { + initialModelId = matchedByName.id; + } else if (matchedByDisplay) { + initialModelId = matchedByDisplay.id; + } + } + + // 2) Fallback to globally configured default LLM + if (initialModelId === null) { + const configuredDisplayName = modelConfig?.llm?.displayName || ""; + const configuredModelName = modelConfig?.llm?.modelName || ""; + + const matchedByDisplay = models.find( + (m) => m.displayName === configuredDisplayName && configuredDisplayName !== "" + ); + const matchedByName = matchedByDisplay + ? null + : models.find( + (m) => m.name === configuredModelName && configuredModelName !== "" + ); + + if (matchedByDisplay) { + initialModelId = matchedByDisplay.id; + } else if (matchedByName) { + initialModelId = matchedByName.id; + } + } + + // 3) Final fallback to first available model + if (initialModelId === null) { + if (models.length > 0) { + initialModelId = models[0].id; + } + } + + if (initialModelId !== null) { + setSelectedModel(initialModelId); } else { message.warning(t("businessLogic.config.error.noAvailableModels")); } @@ -362,11 +413,11 @@ const DocumentListContainer = forwardRef(
- setSummary(e.target.value)} - className="flex-1 min-h-0 mb-5 resize-none text-lg leading-[1.7] p-5" - /> +
+
+ +
+
+ +
+ + {showFormatMenu && ( +
{ + e.stopPropagation(); + }} + > +
+ + +
+
+ )} +
+ + )} + +
+ )} + + {/* Content area */} + {diagramState.showCode ? ( +
+
+            {code}
+          
+
+ ) : ( + <> + {!result || !("dataUrl" in result) ? ( +
+
+
+ ) : ( +
1 + ? isDragging + ? "grabbing" + : "grab" + : "default", + }} + onMouseDown={handleMouseDown} + onMouseMove={handleMouseMove} + onMouseUp={handleMouseUp} + onMouseLeave={handleMouseLeave} + onKeyDown={handleKeyDown} + tabIndex={diagramState.zoomLevel > 1 ? 0 : -1} + > + {ariaLabel { + const img = e.target as HTMLImageElement; + const aspectRatio = img.naturalWidth / img.naturalHeight; + const isWide = aspectRatio > 1.5; // Aspect ratio > 1.5 is considered a wide chart + + setIsWideDiagram(isWide); + }} + /> +
+ )} + + )} +
+ ); +} + +// Memoize the component to prevent unnecessary re-renders +export const Diagram = React.memo(DiagramComponent); diff --git a/frontend/components/ui/markdownRenderer.tsx b/frontend/components/ui/markdownRenderer.tsx index cda50fba..ef3aa5d8 100644 --- a/frontend/components/ui/markdownRenderer.tsx +++ b/frontend/components/ui/markdownRenderer.tsx @@ -1,3 +1,5 @@ +"use client"; + import React from "react"; import { useTranslation } from "react-i18next"; import ReactMarkdown from "react-markdown"; @@ -19,11 +21,13 @@ import { TooltipTrigger, } from "@/components/ui/tooltip"; import { CopyButton } from "@/components/ui/copyButton"; +import { Diagram } from "@/components/ui/Diagram"; interface MarkdownRendererProps { content: string; className?: string; searchResults?: SearchResult[]; + showDiagramToggle?: boolean; } // Get background color for different tool signs @@ -350,6 +354,7 @@ export const MarkdownRenderer: React.FC = ({ content, className, searchResults = [], + showDiagramToggle = true, }) => { const { t } = useTranslation("common"); @@ -397,7 +402,7 @@ export const MarkdownRenderer: React.FC = ({ const processText = (text: string) => { if (typeof text !== "string") return text; - const parts = text.split(/(\[\[[^\]]+\]\])/g); + const parts = text.split(/(\[\[[^\]]+\]\]|:mermaid\[[^\]]+\])/g); return ( <> {parts.map((part, index) => { @@ -426,6 +431,21 @@ export const MarkdownRenderer: React.FC = ({ return ""; } } + // Inline Mermaid using :mermaid[graph LR; A-->B] - removed inline support + const mmd = part.match(/^:mermaid\[([^\]]+)\]$/); + if (mmd) { + const code = mmd[1]; + return ; + } + // Handle line breaks in text content + if (part.includes('\n')) { + return part.split('\n').map((line, lineIndex) => ( + + {line} + {lineIndex < part.split('\n').length - 1 &&
} +
+ )); + } return part; })} @@ -540,6 +560,10 @@ export const MarkdownRenderer: React.FC = ({ {children}

), + // Horizontal rule + hr: () => ( +
+ ), // List item li: ({ children }: any) => (
  • @@ -595,38 +619,44 @@ export const MarkdownRenderer: React.FC = ({ ? children.join("") : children ?? ""; const codeContent = String(raw).replace(/^\n+|\n+$/g, ""); - if (!inline && match && match[1]) { - return ( -
    -
    - - {match[1]} - - -
    -
    - - {codeContent} - + if (match && match[1]) { + // Check if it's a Mermaid diagram + if (match[1] === "mermaid") { + return ; + } + if (!inline) { + return ( +
    +
    + + {match[1]} + + +
    +
    + + {codeContent} + +
    -
    - ); + ); + } } } catch (error) { // Handle error silently diff --git a/frontend/const/constants.ts b/frontend/const/constants.ts index fa7fd964..e47699f4 100644 --- a/frontend/const/constants.ts +++ b/frontend/const/constants.ts @@ -9,3 +9,6 @@ export const TOKEN_REFRESH_CD = 1 * 60 * 1000; export const isProduction = process.env.NODE_ENV === "production"; export const APP_VERSION = "v1.0.0"; + +// Default parameter type constant +export const DEFAULT_TYPE = "string"; diff --git a/frontend/package.json b/frontend/package.json index 827531e1..d818cf45 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -60,6 +60,7 @@ "input-otp": "1.4.1", "katex": "^0.16.11", "lucide-react": "^0.454.0", + "mermaid": "^11.12.0", "next": "15.4.5", "next-i18next": "^15.4.2", "next-themes": "^0.4.4", diff --git a/frontend/postcss.config.mjs b/frontend/postcss.config.mjs index 1a69fd2a..2ef30fcf 100644 --- a/frontend/postcss.config.mjs +++ b/frontend/postcss.config.mjs @@ -2,6 +2,7 @@ const config = { plugins: { tailwindcss: {}, + autoprefixer: {}, }, }; diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 327f0580..495ded4e 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -366,7 +366,8 @@ "toolConfig.toolTest.execute": "Execute Test", "toolConfig.toolTest.result": "Test Result", "toolConfig.button.testTool": "Test Tool", - + "toolConfig.toolTest.manualInput": "Manual Input", + "toolConfig.toolTest.parseMode": "Parse Mode", "toolPool.title": "Select tools", "toolPool.loading": "Loading...", "toolPool.loadingTools": "Loading tools...", @@ -938,5 +939,15 @@ "businessLogic.config.error.loadModelsFailed": "Failed to load available models", "businessLogic.config.error.noAvailableModels": "No available models", "businessLogic.config.error.modelUpdateFailed": "Failed to update agent model", - "businessLogic.config.error.maxStepsUpdateFailed": "Failed to update agent max steps" + "businessLogic.config.error.maxStepsUpdateFailed": "Failed to update agent max steps", + + "diagram.button.showDiagram": "Show Diagram", + "diagram.button.showCode": "Show Code", + "diagram.button.zoomOut": "Zoom Out", + "diagram.button.zoomIn": "Zoom In", + "diagram.button.download": "Download", + "diagram.format.svg": "SVG", + "diagram.format.png": "PNG", + "diagram.format.selectFormat": "Select Format", + "diagram.error.renderFailed": "Render Failed" } diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 9696afcb..d64786fa 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -367,7 +367,8 @@ "toolConfig.toolTest.execute": "执行测试", "toolConfig.toolTest.result": "测试结果", "toolConfig.button.testTool": "工具测试", - + "toolConfig.toolTest.manualInput": "手动输入", + "toolConfig.toolTest.parseMode": "解析模式", "toolPool.title": "选择 Agent 的工具", "toolPool.loading": "加载中...", "toolPool.loadingTools": "加载工具中...", @@ -938,5 +939,15 @@ "businessLogic.config.error.loadModelsFailed": "加载可用模型失败", "businessLogic.config.error.noAvailableModels": "暂无可用的模型", "businessLogic.config.error.modelUpdateFailed": "更新Agent模型失败", - "businessLogic.config.error.maxStepsUpdateFailed": "更新Agent最大步数失败" + "businessLogic.config.error.maxStepsUpdateFailed": "更新Agent最大步数失败", + + "diagram.button.showDiagram": "显示图表", + "diagram.button.showCode": "显示代码", + "diagram.button.zoomOut": "缩小", + "diagram.button.zoomIn": "放大", + "diagram.button.download": "下载", + "diagram.format.svg": "SVG", + 
"diagram.format.png": "PNG", + "diagram.format.selectFormat": "选择格式", + "diagram.error.renderFailed": "渲染失败" } diff --git a/frontend/services/agentConfigService.ts b/frontend/services/agentConfigService.ts index 8d06a0fe..d48073df 100644 --- a/frontend/services/agentConfigService.ts +++ b/frontend/services/agentConfigService.ts @@ -321,7 +321,9 @@ export const updateAgent = async ( constraintPrompt?: string, fewShotsPrompt?: string, displayName?: string, - modelId?: number + modelId?: number, + businessLogicModelName?: string, + businessLogicModelId?: number ) => { try { const response = await fetch(API_ENDPOINTS.agent.update, { @@ -341,6 +343,8 @@ export const updateAgent = async ( duty_prompt: dutyPrompt, constraint_prompt: constraintPrompt, few_shots_prompt: fewShotsPrompt, + business_logic_model_name: businessLogicModelName, + business_logic_model_id: businessLogicModelId, }), }); @@ -504,6 +508,8 @@ export const searchAgentInfo = async (agentId: number) => { constraint_prompt: data.constraint_prompt, few_shots_prompt: data.few_shots_prompt, business_description: data.business_description, + business_logic_model_name: data.business_logic_model_name, + business_logic_model_id: data.business_logic_model_id, provide_run_summary: data.provide_run_summary, enabled: data.enabled, is_available: data.is_available, diff --git a/frontend/styles/react-markdown.css b/frontend/styles/react-markdown.css index 9b61f6d5..3a30eeb3 100644 --- a/frontend/styles/react-markdown.css +++ b/frontend/styles/react-markdown.css @@ -101,6 +101,15 @@ color: var(--color-nord0); } +/* Horizontal Rule Styles */ +.markdown-hr { + border: none; + height: 2px; + background-color: #e5e7eb; + margin: 1.5rem 0; + border-radius: 1px; +} + /* List Item Styles */ .markdown-li { margin-bottom: 0.25rem; @@ -428,3 +437,180 @@ .code-block-content pre::-webkit-scrollbar-thumb:hover { background: #aaa; } + +/* Mermaid Diagram Styles */ +.mermaid-container { + border: 1px solid #e5e7eb; + border-radius: 0.5rem; + overflow: hidden; + background-color: #ffffff; + box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1), 0 1px 2px 0 rgba(0, 0, 0, 0.06); + margin: 1rem 0; + transition: box-shadow 0.2s ease; +} + +.mermaid-container:hover { + box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06); +} + +.mermaid-header { + display: flex; + align-items: center; + justify-content: space-between; + padding: 0.5rem 1rem; + background: linear-gradient(to right, #f9fafb, #f3f4f6); + border-bottom: 1px solid #e5e7eb; +} + +.mermaid-label { + font-size: 0.875rem; + font-weight: 500; + color: #374151; +} + +.mermaid-copy-button { + transition: background-color 0.2s; + border-radius: 0.375rem; + padding: 0.25rem; +} + +.mermaid-copy-button:hover { + background-color: #e5e7eb; +} + +.mermaid-content { + position: relative; +} + +.mermaid-diagram { + display: flex; + justify-content: center; + align-items: center; + padding: 1rem; + min-height: 120px; + overflow: visible; /* allow container to grow with content */ + background: linear-gradient(135deg, #ffffff 0%, #f8fafc 100%); +} + +.mermaid-diagram svg { + width: 100%; + max-width: 100%; + height: auto; + display: block; /* ensure responsive SVG sizing */ + filter: drop-shadow(0 2px 4px rgba(0, 0, 0, 0.05)); +} +.mermaid-inline { + display: inline-block; + vertical-align: middle; + line-height: 1; +} + +.mermaid-inline-svg svg { + height: 1.25em; + width: auto; + display: inline-block; + vertical-align: middle; +} + +/* Generic Diagram helpers for new Diagram component */ 
+.diagram-block { + display: block; + width: 100%; +} + +.diagram-inline { + display: inline-block; + vertical-align: baseline; + line-height: 1; +} + +.mermaid-code-display { + padding: 1rem; + background-color: #f9fafb; + font-size: 0.875rem; + font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; + overflow-x: auto; + white-space: pre-wrap; +} + +.mermaid-loading { + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + padding: 2rem; + color: #6b7280; +} + +.mermaid-error-container { + border: 1px solid #fecaca; + border-radius: 0.5rem; + overflow: hidden; + background-color: #fef2f2; + margin: 1rem 0; +} + +.mermaid-error-header { + display: flex; + align-items: center; + justify-content: space-between; + padding: 0.5rem 1rem; + background-color: #fee2e2; + border-bottom: 1px solid #fecaca; +} + +.mermaid-error-label { + font-size: 0.875rem; + font-weight: 500; + color: #b91c1c; +} + +.mermaid-error-content { + padding: 1rem; +} + +.mermaid-error-message { + color: #dc2626; +} + +/* Gantt chart optimization styles */ +.mermaid svg { + /* Ensure Gantt chart has enough space */ + min-width: 100%; + overflow: visible; +} + +/* Gantt chart timeline label optimization */ +.mermaid svg .axis text { + font-size: 11px !important; + font-weight: 500 !important; + fill: #6b7280 !important; + text-anchor: middle !important; + dominant-baseline: hanging !important; +} + +/* Gantt chart grid line optimization */ +.mermaid svg .grid .tick line { + stroke: #e5e7eb !important; + stroke-width: 1px !important; +} + +/* Gantt chart task bar optimization */ +.mermaid svg .task text { + font-size: 12px !important; + font-weight: 500 !important; + fill: #374151 !important; +} + +/* Gantt chart section title optimization */ +.mermaid svg .section text { + font-size: 14px !important; + font-weight: 600 !important; + fill: #374151 !important; +} + +/* Ensure Gantt chart container has sufficient padding */ +.mermaid { + padding: 20px !important; + margin: 10px 0 !important; +} \ No newline at end of file diff --git a/frontend/types/agentConfig.ts b/frontend/types/agentConfig.ts index b7134d86..3d926a05 100644 --- a/frontend/types/agentConfig.ts +++ b/frontend/types/agentConfig.ts @@ -19,6 +19,8 @@ export interface Agent { constraint_prompt?: string; few_shots_prompt?: string; business_description?: string; + business_logic_model_name?: string; + business_logic_model_id?: number; is_available?: boolean; sub_agent_id_list?: number[]; } @@ -95,6 +97,10 @@ export interface AgentSetupOrchestratorProps { setMainAgentModelId: (value: number | null) => void; mainAgentMaxStep: number; setMainAgentMaxStep: (value: number) => void; + businessLogicModel: string | null; + setBusinessLogicModel: (value: string | null) => void; + businessLogicModelId: number | null; + setBusinessLogicModelId: (value: number | null) => void; tools: Tool[]; subAgentList?: Agent[]; loadingAgents?: boolean; diff --git a/sdk/nexent/vector_database/elasticsearch_core.py b/sdk/nexent/vector_database/elasticsearch_core.py index a908935c..5cd4b27f 100644 --- a/sdk/nexent/vector_database/elasticsearch_core.py +++ b/sdk/nexent/vector_database/elasticsearch_core.py @@ -1,873 +1,903 @@ -import time -import logging -import threading -from typing import List, Dict, Any, Optional -from contextlib import contextmanager -from dataclasses import dataclass -from datetime import datetime, timedelta -from ..core.models.embedding_model import BaseEmbedding -from 
.utils import format_size, format_timestamp, build_weighted_query -from elasticsearch import Elasticsearch, exceptions - -from ..core.nlp.tokenizer import calculate_term_weights - -logger = logging.getLogger("elasticsearch_core") - -@dataclass -class BulkOperation: - """Bulk operation status tracking""" - index_name: str - operation_id: str - start_time: datetime - expected_duration: timedelta - -class ElasticSearchCore: - """ - Core class for Elasticsearch operations including: - - Index management - - Document insertion with embeddings - - Document deletion - - Accurate text search - - Semantic vector search - - Hybrid search - - Index statistics - """ - - def __init__( - self, - host: Optional[str], - api_key: Optional[str], - verify_certs: bool = False, - ssl_show_warn: bool = False, - ): - """ - Initialize ElasticSearchCore with Elasticsearch client and JinaEmbedding model. - - Args: - host: Elasticsearch host URL (defaults to env variable) - api_key: Elasticsearch API key (defaults to env variable) - verify_certs: Whether to verify SSL certificates - ssl_show_warn: Whether to show SSL warnings - """ - # Get credentials from environment if not provided - self.host = host - self.api_key = api_key - - # Initialize Elasticsearch client with HTTPS support - self.client = Elasticsearch( - self.host, - api_key=self.api_key, - verify_certs=verify_certs, - ssl_show_warn=ssl_show_warn, - request_timeout=20, - max_retries=3, # Reduce retries for faster failure detection - retry_on_timeout=True, - retry_on_status=[502, 503, 504], # Retry on these status codes, - ) - - # Initialize embedding model - self._bulk_operations: Dict[str, List[BulkOperation]] = {} - self._settings_lock = threading.Lock() - self._operation_counter = 0 - - # Embedding API limits - self.max_texts_per_batch = 2048 - self.max_tokens_per_text = 8192 - self.max_total_tokens = 100000 - - # ---- INDEX MANAGEMENT ---- - - def create_vector_index(self, index_name: str, embedding_dim: Optional[int] = None) -> bool: - """ - Create a new vector search index with appropriate mappings in a celery-friendly way. 
- - Args: - index_name: Name of the index to create - embedding_dim: Dimension of the embedding vectors (optional, will use model's dim if not provided) - - Returns: - bool: True if creation was successful - """ - try: - # Use provided embedding_dim or get from model - actual_embedding_dim = embedding_dim or 1024 - - # Use balanced fixed settings to avoid dynamic adjustment - settings = { - "number_of_shards": 1, - "number_of_replicas": 0, - "refresh_interval": "5s", # not too fast, not too slow - "index": { - "max_result_window": 50000, - "translog": { - "durability": "async", - "sync_interval": "5s" - }, - "write": { - "wait_for_active_shards": "1" - }, - # Memory optimization for bulk operations - "merge": { - "policy": { - "max_merge_at_once": 5, - "segments_per_tier": 5 - } - } - } - } - - # Check if index already exists - if self.client.indices.exists(index=index_name): - logger.info(f"Index {index_name} already exists, skipping creation") - self._ensure_index_ready(index_name) - return True - - # Define the mapping with vector field - mappings = { - "properties": { - "id": {"type": "keyword"}, - "title": {"type": "text"}, - "filename": {"type": "keyword"}, - "path_or_url": {"type": "keyword"}, - "language": {"type": "keyword"}, - "author": {"type": "keyword"}, - "date": {"type": "date"}, - "content": {"type": "text"}, - "process_source": {"type": "keyword"}, - "embedding_model_name": {"type": "keyword"}, - "file_size": {"type": "long"}, - "create_time": {"type": "date"}, - "embedding": { - "type": "dense_vector", - "dims": actual_embedding_dim, - "index": "true", - "similarity": "cosine", - }, - } - } - - # Create the index with the defined mappings - self.client.indices.create( - index=index_name, - mappings=mappings, - settings=settings, - wait_for_active_shards="1" - ) - - # Force refresh to ensure visibility - self._force_refresh_with_retry(index_name) - self._ensure_index_ready(index_name) - - logger.info(f"Successfully created index: {index_name}") - return True - - except exceptions.RequestError as e: - # Handle the case where index already exists (error 400) - if "resource_already_exists_exception" in str(e): - logger.info(f"Index {index_name} already exists, skipping creation") - self._ensure_index_ready(index_name) - return True - logger.error(f"Error creating index: {str(e)}") - return False - except Exception as e: - logger.error(f"Error creating index: {str(e)}") - return False - - def _force_refresh_with_retry(self, index_name: str, max_retries: int = 3) -> bool: - """ - Force refresh with retry - synchronous version - """ - for attempt in range(max_retries): - try: - self.client.indices.refresh(index=index_name) - return True - except Exception as e: - if attempt < max_retries - 1: - time.sleep(0.5 * (attempt + 1)) - continue - logger.error(f"Failed to refresh index {index_name}: {e}") - return False - return False - - def _ensure_index_ready(self, index_name: str, timeout: int = 10) -> bool: - """ - Ensure index is ready, avoid 503 error - synchronous version - """ - start_time = time.time() - - while time.time() - start_time < timeout: - try: - # Check cluster health - health = self.client.cluster.health( - index=index_name, - wait_for_status="yellow", - timeout="1s" - ) - - if health["status"] in ["green", "yellow"]: - # Double check: try simple query - self.client.search( - index=index_name, - body={"query": {"match_all": {}}, "size": 0} - ) - return True - - except Exception as e: - time.sleep(0.1) - - logger.warning(f"Index {index_name} may not be fully ready 
after {timeout}s") - return False - - @contextmanager - def bulk_operation_context(self, index_name: str, estimated_duration: int = 60): - """ - Celery-friendly context manager - using threading.Lock - """ - operation_id = f"bulk_{self._operation_counter}_{threading.current_thread().name}" - self._operation_counter += 1 - - operation = BulkOperation( - index_name=index_name, - operation_id=operation_id, - start_time=datetime.now(), - expected_duration=timedelta(seconds=estimated_duration) - ) - - with self._settings_lock: - # Record current operation - if index_name not in self._bulk_operations: - self._bulk_operations[index_name] = [] - self._bulk_operations[index_name].append(operation) - - # If this is the first bulk operation, adjust settings - if len(self._bulk_operations[index_name]) == 1: - self._apply_bulk_settings(index_name) - - try: - yield operation_id - finally: - with self._settings_lock: - # Remove operation record - self._bulk_operations[index_name] = [ - op for op in self._bulk_operations[index_name] - if op.operation_id != operation_id - ] - - # If there are no other bulk operations, restore settings - if not self._bulk_operations[index_name]: - self._restore_normal_settings(index_name) - del self._bulk_operations[index_name] - - def _apply_bulk_settings(self, index_name: str): - """Apply bulk operation optimization settings""" - try: - self.client.indices.put_settings( - index=index_name, - body={ - "refresh_interval": "30s", - "translog.durability": "async", - "translog.sync_interval": "10s" - } - ) - logger.info(f"Applied bulk settings to {index_name}") - except Exception as e: - logger.warning(f"Failed to apply bulk settings: {e}") - - def _restore_normal_settings(self, index_name: str): - """Restore normal settings""" - try: - self.client.indices.put_settings( - index=index_name, - body={ - "refresh_interval": "5s", - "translog.durability": "request" - } - ) - # Refresh after restoration - self._force_refresh_with_retry(index_name) - logger.info(f"Restored normal settings for {index_name}") - except Exception as e: - logger.warning(f"Failed to restore settings: {e}") - - def delete_index(self, index_name: str) -> bool: - """ - Delete an entire index - - Args: - index_name: Name of the index to delete - - Returns: - bool: True if deletion was successful - """ - try: - self.client.indices.delete(index=index_name) - logger.info(f"Successfully deleted the index: {index_name}") - return True - except exceptions.NotFoundError: - logger.info(f"Index {index_name} not found") - return False - except Exception as e: - logger.error(f"Error deleting index: {str(e)}") - return False - - def get_user_indices(self, index_pattern: str = "*") -> List[str]: - """ - Get list of user created indices (excluding system indices) - - Args: - index_pattern: Pattern to match index names - - Returns: - List of index names - """ - try: - indices = self.client.indices.get_alias(index=index_pattern) - # Filter out system indices (starting with '.') - return [index_name for index_name in indices.keys() if not index_name.startswith('.')] - except Exception as e: - logger.error(f"Error getting user indices: {str(e)}") - return [] - - # ---- DOCUMENT OPERATIONS ---- - - def index_documents( - self, - index_name: str, - embedding_model: BaseEmbedding, - documents: List[Dict[str, Any]], - batch_size: int = 2048, - content_field: str = "content" - ) -> int: - """ - Smart batch insertion - automatically selecting strategy based on data size - - Args: - index_name: Name of the index to add documents to - 
embedding_model: Model used to generate embeddings for documents - documents: List of document dictionaries - batch_size: Number of documents to process at once - content_field: Field to use for generating embeddings - - Returns: - int: Number of documents successfully indexed - """ - logger.info(f"Indexing {len(documents)} chunks to {index_name}") - - # Handle empty documents list - if not documents: - return 0 - - # Smart strategy selection - total_docs = len(documents) - if total_docs < 100: - # Small data: direct insertion, using wait_for refresh - return self._small_batch_insert(index_name, documents, content_field, embedding_model) - else: - # Large data: using context manager - estimated_duration = max(60, total_docs // 100) - with self.bulk_operation_context(index_name, estimated_duration): - return self._large_batch_insert(index_name, documents, batch_size, content_field, embedding_model) - - def _small_batch_insert(self, index_name: str, documents: List[Dict[str, Any]], content_field: str, embedding_model:BaseEmbedding) -> int: - """Small batch insertion: real-time""" - try: - # Preprocess documents - processed_docs = self._preprocess_documents(documents, content_field) - - # Get embeddings - inputs = [doc[content_field] for doc in processed_docs] - embeddings = embedding_model.get_embeddings(inputs) - - # Prepare bulk operations - operations = [] - for doc, embedding in zip(processed_docs, embeddings): - operations.append({"index": {"_index": index_name}}) - doc["embedding"] = embedding - if "embedding_model_name" not in doc: - doc["embedding_model_name"] = embedding_model.embedding_model_name - operations.append(doc) - - # Execute bulk insertion, wait for refresh to complete - response = self.client.bulk( - index=index_name, - operations=operations, - refresh='wait_for' - ) - - # Handle errors - self._handle_bulk_errors(response) - - logger.info(f"Small batch insert completed: {len(documents)} chunks indexed.") - return len(documents) - - except Exception as e: - logger.error(f"Small batch insert failed: {e}") - return 0 - - def _large_batch_insert(self, index_name: str, documents: List[Dict[str, Any]], batch_size: int, content_field: str, embedding_model: BaseEmbedding) -> int: - """ - Large batch insertion with sub-batching for embedding API. - Splits large document batches into smaller chunks to respect embedding API limits before bulk inserting into Elasticsearch. 
- """ - try: - processed_docs = self._preprocess_documents(documents, content_field) - total_indexed = 0 - total_docs = len(processed_docs) - es_total_batches = (total_docs + batch_size - 1) // batch_size - - for i in range(0, total_docs, batch_size): - es_batch = processed_docs[i:i + batch_size] - es_batch_num = i // batch_size + 1 - - # Store documents and their embeddings for this Elasticsearch batch - doc_embedding_pairs = [] - - # Sub-batch for embedding API - embedding_batch_size = self.max_texts_per_batch - for j in range(0, len(es_batch), embedding_batch_size): - embedding_sub_batch = es_batch[j:j + embedding_batch_size] - - try: - inputs = [doc[content_field] for doc in embedding_sub_batch] - embeddings = embedding_model.get_embeddings(inputs) - - for doc, embedding in zip(embedding_sub_batch, embeddings): - doc_embedding_pairs.append((doc, embedding)) - - except Exception as e: - logger.error(f"Embedding API error: {e}, ES batch num: {es_batch_num}, sub-batch start: {j}, size: {len(embedding_sub_batch)}") - continue - - # Perform a single bulk insert for the entire Elasticsearch batch - if not doc_embedding_pairs: - logger.warning(f"No documents with embeddings to index for ES batch {es_batch_num}") - continue - - operations = [] - for doc, embedding in doc_embedding_pairs: - operations.append({"index": {"_index": index_name}}) - doc["embedding"] = embedding - if "embedding_model_name" not in doc: - doc["embedding_model_name"] = getattr(embedding_model, 'embedding_model_name', 'unknown') - operations.append(doc) - - try: - response = self.client.bulk( - index=index_name, - operations=operations, - refresh=False - ) - self._handle_bulk_errors(response) - total_indexed += len(doc_embedding_pairs) - logger.info(f"Processed ES batch {es_batch_num}/{es_total_batches}, indexed {len(doc_embedding_pairs)} documents.") - - except Exception as e: - logger.error(f"Bulk insert error: {e}, ES batch num: {es_batch_num}") - continue - - if es_batch_num % 10 == 0: - time.sleep(0.1) - - self._force_refresh_with_retry(index_name) - logger.info(f"Large batch insert completed: {total_indexed} chunks indexed.") - return total_indexed - except Exception as e: - logger.error(f"Large batch insert failed: {e}") - return 0 - - def _preprocess_documents(self, documents: List[Dict[str, Any]], content_field: str) -> List[Dict[str, Any]]: - """Ensure all documents have the required fields and set default values""" - current_time = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()) - current_date = time.strftime('%Y-%m-%d', time.gmtime()) - - processed_docs = [] - for doc in documents: - # Create a copy of the document to avoid modifying the original data - doc_copy = doc.copy() - - # Set create_time if not present - if not doc_copy.get("create_time"): - doc_copy["create_time"] = current_time - - if not doc_copy.get("date"): - doc_copy["date"] = current_date - - # Ensure file_size is present (default to 0 if not provided) - if not doc_copy.get("file_size"): - logger.warning(f"File size not found in {doc_copy}") - doc_copy["file_size"] = 0 - - # Ensure process_source is present - if not doc_copy.get("process_source"): - doc_copy["process_source"] = "Unstructured" - - # Ensure all documents have an ID - if not doc_copy.get("id"): - doc_copy["id"] = f"{int(time.time())}_{hash(doc_copy[content_field])}"[:20] - - processed_docs.append(doc_copy) - - return processed_docs - - def _handle_bulk_errors(self, response: Dict[str, Any]) -> None: - """Handle bulk operation errors""" - if response.get('errors'): - for item in 
response['items']: - if 'error' in item.get('index', {}): - error_info = item['index']['error'] - error_type = error_info.get('type') - error_reason = error_info.get('reason') - error_cause = error_info.get('caused_by', {}) - - if error_type == 'version_conflict_engine_exception': - # ignore version conflict - continue - else: - logger.error(f"FATAL ERROR {error_type}: {error_reason}") - if error_cause: - logger.error(f"Caused By: {error_cause.get('type')}: {error_cause.get('reason')}") - - def delete_documents_by_path_or_url(self, index_name: str, path_or_url: str) -> int: - """ - Delete documents based on their path_or_url field - - Args: - index_name: Name of the index to delete documents from - path_or_url: The URL or path of the documents to delete - - Returns: - int: Number of documents deleted - """ - try: - result = self.client.delete_by_query( - index=index_name, - body={ - "query": { - "term": { - "path_or_url": path_or_url - } - } - } - ) - logger.info(f"Successfully deleted {result['deleted']} documents with path_or_url: {path_or_url} from index: {index_name}") - return result['deleted'] - except Exception as e: - logger.error(f"Error deleting documents: {str(e)}") - return 0 - - # ---- SEARCH OPERATIONS ---- - - def accurate_search(self, index_names: List[str], query_text: str, top_k: int = 5) -> List[Dict[str, Any]]: - """ - Search for documents using fuzzy text matching across multiple indices. - - Args: - index_names: Name of the index to search in - query_text: The text query to search for - top_k: Number of results to return - - Returns: - List of search results with scores and document content - """ - # Join index names for multi-index search - index_pattern = ",".join(index_names) - - weights = calculate_term_weights(query_text) - - # Prepare the search query using match query for fuzzy matching - search_query = build_weighted_query(query_text, weights) | { - "size": top_k, - "_source": { - "excludes": ["embedding"] - } - } - - # Execute the search across multiple indices - return self.exec_query(index_pattern, search_query) - - def exec_query(self, index_pattern, search_query): - response = self.client.search( - index=index_pattern, - body=search_query - ) - # Process and return results - results = [] - for hit in response["hits"]["hits"]: - results.append({ - "score": hit["_score"], - "document": hit["_source"], - "index": hit["_index"] # Include source index in results - }) - return results - - def semantic_search(self, index_names: List[str], query_text: str, embedding_model: BaseEmbedding, top_k: int = 5) -> List[Dict[str, Any]]: - """ - Search for similar documents using vector similarity across multiple indices. 
- - Args: - index_names: List of index names to search in - query_text: The text query to search for - embedding_model: The embedding model to use - top_k: Number of results to return - - Returns: - List of search results with scores and document content - """ - # Join index names for multi-index search - index_pattern = ",".join(index_names) - - # Get query embedding - query_embedding = embedding_model.get_embeddings(query_text)[0] - - # Prepare the search query - search_query = { - "knn": { - "field": "embedding", - "query_vector": query_embedding, - "k": top_k, - "num_candidates": top_k * 2, - }, - "size": top_k, - "_source": { - "excludes": ["embedding"] - } - } - - # Execute the search across multiple indices - return self.exec_query(index_pattern, search_query) - - def hybrid_search( - self, - index_names: List[str], - query_text: str, - embedding_model: BaseEmbedding, - top_k: int = 5, - weight_accurate: float = 0.3 - ) -> List[Dict[str, Any]]: - """ - Hybrid search method, combining accurate matching and semantic search results across multiple indices. - - Args: - index_names: List of index names to search in - query_text: The text query to search for - embedding_model: The embedding model to use - top_k: Number of results to return - weight_accurate: The weight of the accurate matching score (0-1), the semantic search weight is 1-weight_accurate - - Returns: - List of search results sorted by combined score - """ - # Get results from both searches - accurate_results = self.accurate_search(index_names, query_text, top_k=top_k) - semantic_results = self.semantic_search(index_names, query_text, embedding_model=embedding_model, top_k=top_k) - - # Create a mapping from document ID to results - combined_results = {} - - # Process accurate matching results - for result in accurate_results: - try: - doc_id = result['document']['id'] - combined_results[doc_id] = { - 'document': result['document'], - 'accurate_score': result.get('score', 0), - 'semantic_score': 0, - 'index': result['index'] # Keep track of source index - } - except KeyError as e: - logger.warning(f"Warning: Missing required field in accurate result: {e}") - continue - - # Process semantic search results - for result in semantic_results: - try: - doc_id = result['document']['id'] - if doc_id in combined_results: - combined_results[doc_id]['semantic_score'] = result.get('score', 0) - else: - combined_results[doc_id] = { - 'document': result['document'], - 'accurate_score': 0, - 'semantic_score': result.get('score', 0), - 'index': result['index'] # Keep track of source index - } - except KeyError as e: - logger.warning(f"Warning: Missing required field in semantic result: {e}") - continue - - # Calculate maximum scores - max_accurate = max([r.get('score', 0) for r in accurate_results]) if accurate_results else 1 - max_semantic = max([r.get('score', 0) for r in semantic_results]) if semantic_results else 1 - - # Calculate combined scores and sort - results = [] - for doc_id, result in combined_results.items(): - try: - # Get scores safely - accurate_score = result.get('accurate_score', 0) - semantic_score = result.get('semantic_score', 0) - - # Normalize scores - normalized_accurate = accurate_score / max_accurate if max_accurate > 0 else 0 - normalized_semantic = semantic_score / max_semantic if max_semantic > 0 else 0 - - # Calculate weighted combined score - combined_score = (weight_accurate * normalized_accurate + - (1 - weight_accurate) * normalized_semantic) - - results.append({ - 'score': combined_score, - 'document': 
result['document'], - 'index': result['index'], # Include source index in results - 'scores': { - 'accurate': normalized_accurate, - 'semantic': normalized_semantic - } - }) - except KeyError as e: - logger.warning(f"Warning: Error processing result for doc_id {doc_id}: {e}") - continue - - # Sort by combined score and return top k results - results.sort(key=lambda x: x['score'], reverse=True) - return results[:top_k] - - # ---- STATISTICS AND MONITORING ---- - def get_file_list_with_details(self, index_name: str) -> List[Dict[str, Any]]: - """ - Get a list of unique path_or_url values with their file_size and create_time - - Args: - index_name: Name of the index to query - - Returns: - List of dictionaries with path_or_url, file_size, and create_time - """ - agg_query = { - "size": 0, - "aggs": { - "unique_sources": { - "terms": { - "field": "path_or_url", - "size": 1000 # Limit to 1000 files for performance - }, - "aggs": { - "file_sample": { - "top_hits": { - "size": 1, - "_source": ["path_or_url", "file_size", "create_time", "filename"] - } - } - } - } - } - } - - try: - result = self.client.search( - index=index_name, - body=agg_query - ) - - file_list = [] - for bucket in result['aggregations']['unique_sources']['buckets']: - source = bucket['file_sample']['hits']['hits'][0]['_source'] - file_info = { - "path_or_url": source["path_or_url"], - "filename": source.get("filename", ""), - "file_size": source.get("file_size", 0), - "create_time": source.get("create_time", None) - } - file_list.append(file_info) - - return file_list - except Exception as e: - logger.error(f"Error getting file list: {str(e)}") - return [] - - def get_index_mapping(self, index_names: List[str]) -> Dict[str, List[str]]: - """Get field mappings for multiple indices""" - mappings = {} - for index_name in index_names: - try: - mapping = self.client.indices.get_mapping(index=index_name) - if mapping[index_name].get('mappings') and mapping[index_name]['mappings'].get('properties'): - mappings[index_name] = list(mapping[index_name]['mappings']['properties'].keys()) - else: - mappings[index_name] = [] - except Exception as e: - logger.error(f"Error getting mapping for index {index_name}: {str(e)}") - mappings[index_name] = [] - return mappings - - def get_index_stats(self, index_names: List[str], embedding_dim: Optional[int] = None) -> Dict[str, Dict[str, Dict[str, Any]]]: - """Get formatted statistics for multiple indices""" - all_stats = {} - for index_name in index_names: - try: - stats = self.client.indices.stats(index=index_name) - settings = self.client.indices.get_settings(index=index_name) - - # Merge query - agg_query = { - "size": 0, - "aggs": { - "unique_path_or_url_count": { - "cardinality": { - "field": "path_or_url" - } - }, - "process_sources": { - "terms": { - "field": "process_source", - "size": 10 - } - }, - "embedding_models": { - "terms": { - "field": "embedding_model_name", - "size": 10 - } - } - } - } - - # Execute query - agg_result = self.client.search( - index=index_name, - body=agg_query - ) - - unique_sources_count = agg_result['aggregations']['unique_path_or_url_count']['value'] - process_source = agg_result['aggregations']['process_sources']['buckets'][0]['key'] if agg_result['aggregations']['process_sources']['buckets'] else "" - embedding_model = agg_result['aggregations']['embedding_models']['buckets'][0]['key'] if agg_result['aggregations']['embedding_models']['buckets'] else "" - - index_stats = stats["indices"][index_name]["primaries"] - - # Get creation and update timestamps from 
settings - creation_date = int(settings[index_name]['settings']['index']['creation_date']) - # Update time defaults to creation time if not modified - update_time = creation_date - - all_stats[index_name] = { - "base_info": { - "doc_count": unique_sources_count, - "chunk_count": index_stats["docs"]["count"], - "store_size": format_size(index_stats["store"]["size_in_bytes"]), - "process_source": process_source, - "embedding_model": embedding_model, - "embedding_dim": embedding_dim or 1024, - "creation_date": creation_date, - "update_date": update_time - }, - "search_performance": { - "total_search_count": index_stats["search"]["query_total"], - "hit_count": index_stats["request_cache"]["hit_count"], - } - } - except Exception as e: - logger.error(f"Error getting stats for index {index_name}: {str(e)}") - all_stats[index_name] = {"error": str(e)} - - return all_stats +import time +import logging +import threading +from typing import List, Dict, Any, Optional +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import datetime, timedelta +from ..core.models.embedding_model import BaseEmbedding +from .utils import format_size, format_timestamp, build_weighted_query +from elasticsearch import Elasticsearch, exceptions + +from ..core.nlp.tokenizer import calculate_term_weights + +logger = logging.getLogger("elasticsearch_core") + +@dataclass +class BulkOperation: + """Bulk operation status tracking""" + index_name: str + operation_id: str + start_time: datetime + expected_duration: timedelta + +class ElasticSearchCore: + """ + Core class for Elasticsearch operations including: + - Index management + - Document insertion with embeddings + - Document deletion + - Accurate text search + - Semantic vector search + - Hybrid search + - Index statistics + """ + + def __init__( + self, + host: Optional[str], + api_key: Optional[str], + verify_certs: bool = False, + ssl_show_warn: bool = False, + ): + """ + Initialize ElasticSearchCore with Elasticsearch client and JinaEmbedding model. + + Args: + host: Elasticsearch host URL (defaults to env variable) + api_key: Elasticsearch API key (defaults to env variable) + verify_certs: Whether to verify SSL certificates + ssl_show_warn: Whether to show SSL warnings + """ + # Get credentials from environment if not provided + self.host = host + self.api_key = api_key + + # Initialize Elasticsearch client with HTTPS support + self.client = Elasticsearch( + self.host, + api_key=self.api_key, + verify_certs=verify_certs, + ssl_show_warn=ssl_show_warn, + request_timeout=20, + max_retries=3, # Reduce retries for faster failure detection + retry_on_timeout=True, + retry_on_status=[502, 503, 504], # Retry on these status codes, + ) + + # Initialize embedding model + self._bulk_operations: Dict[str, List[BulkOperation]] = {} + self._settings_lock = threading.Lock() + self._operation_counter = 0 + + # Embedding API limits + self.max_texts_per_batch = 2048 + self.max_tokens_per_text = 8192 + self.max_total_tokens = 100000 + self.max_retries = 3 # Number of retries for failed embedding batches + + # ---- INDEX MANAGEMENT ---- + + def create_vector_index(self, index_name: str, embedding_dim: Optional[int] = None) -> bool: + """ + Create a new vector search index with appropriate mappings in a celery-friendly way. 
+ + Args: + index_name: Name of the index to create + embedding_dim: Dimension of the embedding vectors (optional, will use model's dim if not provided) + + Returns: + bool: True if creation was successful + """ + try: + # Use provided embedding_dim or get from model + actual_embedding_dim = embedding_dim or 1024 + + # Use balanced fixed settings to avoid dynamic adjustment + settings = { + "number_of_shards": 1, + "number_of_replicas": 0, + "refresh_interval": "5s", + "index": { + "max_result_window": 50000, + "translog": { + "durability": "async", + "sync_interval": "5s" + }, + "write": { + "wait_for_active_shards": "1" + }, + # Memory optimization for bulk operations + "merge": { + "policy": { + "max_merge_at_once": 5, + "segments_per_tier": 5 + } + } + } + } + + # Check if index already exists + if self.client.indices.exists(index=index_name): + logger.info(f"Index {index_name} already exists, skipping creation") + self._ensure_index_ready(index_name) + return True + + # Define the mapping with vector field + mappings = { + "properties": { + "id": {"type": "keyword"}, + "title": {"type": "text"}, + "filename": {"type": "keyword"}, + "path_or_url": {"type": "keyword"}, + "language": {"type": "keyword"}, + "author": {"type": "keyword"}, + "date": {"type": "date"}, + "content": {"type": "text"}, + "process_source": {"type": "keyword"}, + "embedding_model_name": {"type": "keyword"}, + "file_size": {"type": "long"}, + "create_time": {"type": "date"}, + "embedding": { + "type": "dense_vector", + "dims": actual_embedding_dim, + "index": "true", + "similarity": "cosine", + }, + } + } + + # Create the index with the defined mappings + self.client.indices.create( + index=index_name, + mappings=mappings, + settings=settings, + wait_for_active_shards="1" + ) + + # Force refresh to ensure visibility + self._force_refresh_with_retry(index_name) + self._ensure_index_ready(index_name) + + logger.info(f"Successfully created index: {index_name}") + return True + + except exceptions.RequestError as e: + # Handle the case where index already exists (error 400) + if "resource_already_exists_exception" in str(e): + logger.info(f"Index {index_name} already exists, skipping creation") + self._ensure_index_ready(index_name) + return True + logger.error(f"Error creating index: {str(e)}") + return False + except Exception as e: + logger.error(f"Error creating index: {str(e)}") + return False + + def _force_refresh_with_retry(self, index_name: str, max_retries: int = 3) -> bool: + """ + Force refresh with retry - synchronous version + """ + for attempt in range(max_retries): + try: + self.client.indices.refresh(index=index_name) + return True + except Exception as e: + if attempt < max_retries - 1: + time.sleep(0.5 * (attempt + 1)) + continue + logger.error(f"Failed to refresh index {index_name}: {e}") + return False + return False + + def _ensure_index_ready(self, index_name: str, timeout: int = 10) -> bool: + """ + Ensure index is ready, avoid 503 error - synchronous version + """ + start_time = time.time() + + while time.time() - start_time < timeout: + try: + # Check cluster health + health = self.client.cluster.health( + index=index_name, + wait_for_status="yellow", + timeout="1s" + ) + + if health["status"] in ["green", "yellow"]: + # Double check: try simple query + self.client.search( + index=index_name, + body={"query": {"match_all": {}}, "size": 0} + ) + return True + + except Exception as e: + time.sleep(0.1) + + logger.warning(f"Index {index_name} may not be fully ready after {timeout}s") + return 
False + + @contextmanager + def bulk_operation_context(self, index_name: str, estimated_duration: int = 60): + """ + Celery-friendly context manager - using threading.Lock + """ + operation_id = f"bulk_{self._operation_counter}_{threading.current_thread().name}" + self._operation_counter += 1 + + operation = BulkOperation( + index_name=index_name, + operation_id=operation_id, + start_time=datetime.now(), + expected_duration=timedelta(seconds=estimated_duration) + ) + + with self._settings_lock: + # Record current operation + if index_name not in self._bulk_operations: + self._bulk_operations[index_name] = [] + self._bulk_operations[index_name].append(operation) + + # If this is the first bulk operation, adjust settings + if len(self._bulk_operations[index_name]) == 1: + self._apply_bulk_settings(index_name) + + try: + yield operation_id + finally: + with self._settings_lock: + # Remove operation record + self._bulk_operations[index_name] = [ + op for op in self._bulk_operations[index_name] + if op.operation_id != operation_id + ] + + # If there are no other bulk operations, restore settings + if not self._bulk_operations[index_name]: + self._restore_normal_settings(index_name) + del self._bulk_operations[index_name] + + def _apply_bulk_settings(self, index_name: str): + """Apply bulk operation optimization settings""" + try: + self.client.indices.put_settings( + index=index_name, + body={ + "refresh_interval": "30s", + "translog.durability": "async", + "translog.sync_interval": "10s" + } + ) + logger.debug(f"Applied bulk settings to {index_name}") + except Exception as e: + logger.warning(f"Failed to apply bulk settings: {e}") + + def _restore_normal_settings(self, index_name: str): + """Restore normal settings""" + try: + self.client.indices.put_settings( + index=index_name, + body={ + "refresh_interval": "5s", + "translog.durability": "request" + } + ) + # Refresh after restoration + self._force_refresh_with_retry(index_name) + logger.info(f"Restored normal settings for {index_name}") + except Exception as e: + logger.warning(f"Failed to restore settings: {e}") + + def delete_index(self, index_name: str) -> bool: + """ + Delete an entire index + + Args: + index_name: Name of the index to delete + + Returns: + bool: True if deletion was successful + """ + try: + self.client.indices.delete(index=index_name) + logger.info(f"Successfully deleted the index: {index_name}") + return True + except exceptions.NotFoundError: + logger.info(f"Index {index_name} not found") + return False + except Exception as e: + logger.error(f"Error deleting index: {str(e)}") + return False + + def get_user_indices(self, index_pattern: str = "*") -> List[str]: + """ + Get list of user created indices (excluding system indices) + + Args: + index_pattern: Pattern to match index names + + Returns: + List of index names + """ + try: + indices = self.client.indices.get_alias(index=index_pattern) + # Filter out system indices (starting with '.') + return [index_name for index_name in indices.keys() if not index_name.startswith('.')] + except Exception as e: + logger.error(f"Error getting user indices: {str(e)}") + return [] + + # ---- DOCUMENT OPERATIONS ---- + + def index_documents( + self, + index_name: str, + embedding_model: BaseEmbedding, + documents: List[Dict[str, Any]], + batch_size: int = 64, + content_field: str = "content" + ) -> int: + """ + Smart batch insertion - automatically selecting strategy based on data size + + Args: + index_name: Name of the index to add documents to + embedding_model: Model used 
to generate embeddings for documents + documents: List of document dictionaries + batch_size: Number of documents to process at once + content_field: Field to use for generating embeddings + + Returns: + int: Number of documents successfully indexed + """ + logger.info(f"Indexing {len(documents)} chunks to {index_name}") + + # Handle empty documents list + if not documents: + return 0 + + # Smart strategy selection + total_docs = len(documents) + if total_docs < 64: + # Small data: direct insertion, using wait_for refresh + return self._small_batch_insert(index_name, documents, content_field, embedding_model) + else: + # Large data: using context manager + estimated_duration = max(60, total_docs // 100) + with self.bulk_operation_context(index_name, estimated_duration): + return self._large_batch_insert(index_name, documents, batch_size, content_field, embedding_model) + + def _small_batch_insert(self, index_name: str, documents: List[Dict[str, Any]], content_field: str, embedding_model:BaseEmbedding) -> int: + """Small batch insertion: real-time""" + try: + # Preprocess documents + processed_docs = self._preprocess_documents(documents, content_field) + + # Get embeddings + inputs = [doc[content_field] for doc in processed_docs] + embeddings = embedding_model.get_embeddings(inputs) + + # Prepare bulk operations + operations = [] + for doc, embedding in zip(processed_docs, embeddings): + operations.append({"index": {"_index": index_name}}) + doc["embedding"] = embedding + if "embedding_model_name" not in doc: + doc["embedding_model_name"] = embedding_model.embedding_model_name + operations.append(doc) + + # Execute bulk insertion, wait for refresh to complete + response = self.client.bulk( + index=index_name, + operations=operations, + refresh='wait_for' + ) + + # Handle errors + self._handle_bulk_errors(response) + + logger.info(f"Small batch insert completed: {len(documents)} chunks indexed.") + return len(documents) + + except Exception as e: + logger.error(f"Small batch insert failed: {e}") + return 0 + + def _large_batch_insert(self, index_name: str, documents: List[Dict[str, Any]], batch_size: int, content_field: str, embedding_model: BaseEmbedding) -> int: + """ + Large batch insertion with sub-batching for embedding API. + Splits large document batches into smaller chunks to respect embedding API limits before bulk inserting into Elasticsearch. 
+ """ + try: + processed_docs = self._preprocess_documents(documents, content_field) + total_indexed = 0 + total_docs = len(processed_docs) + es_total_batches = (total_docs + batch_size - 1) // batch_size + start_time = time.time() + + logger.info( + f"=== [INDEXING START] Total chunks: {total_docs}, ES batch size: {batch_size}, Total ES batches: {es_total_batches} ===") + + for i in range(0, total_docs, batch_size): + es_batch = processed_docs[i:i + batch_size] + es_batch_num = i // batch_size + 1 + es_batch_start_time = time.time() + + # Store documents and their embeddings for this Elasticsearch batch + doc_embedding_pairs = [] + + # Sub-batch for embedding API + embedding_batch_size = 64 + for j in range(0, len(es_batch), embedding_batch_size): + embedding_sub_batch = es_batch[j:j + embedding_batch_size] + # Retry logic for embedding API call (3 retries, 1s delay) + # Note: embedding_model.get_embeddings() already has built-in retries with exponential backoff + # This outer retry handles additional failures + max_retries = 3 + retry_delay = 1.0 + success = False + + for retry_attempt in range(max_retries): + try: + inputs = [doc[content_field] + for doc in embedding_sub_batch] + embeddings = embedding_model.get_embeddings(inputs) + + for doc, embedding in zip(embedding_sub_batch, embeddings): + doc_embedding_pairs.append((doc, embedding)) + + success = True + break # Success, exit retry loop + + except Exception as e: + if retry_attempt < max_retries - 1: + logger.warning( + f"Embedding API error (attempt {retry_attempt + 1}/{max_retries}): {e}, ES batch num: {es_batch_num}, sub-batch start: {j}, size: {len(embedding_sub_batch)}. Retrying in {retry_delay}s...") + time.sleep(retry_delay) + else: + logger.error( + f"Embedding API error after {max_retries} attempts: {e}, ES batch num: {es_batch_num}, sub-batch start: {j}, size: {len(embedding_sub_batch)}") + + if not success: + # Skip this sub-batch after all retries failed + continue + + # Perform a single bulk insert for the entire Elasticsearch batch + if not doc_embedding_pairs: + logger.warning(f"No documents with embeddings to index for ES batch {es_batch_num}") + continue + + operations = [] + for doc, embedding in doc_embedding_pairs: + operations.append({"index": {"_index": index_name}}) + doc["embedding"] = embedding + if "embedding_model_name" not in doc: + doc["embedding_model_name"] = getattr(embedding_model, 'embedding_model_name', 'unknown') + operations.append(doc) + + try: + response = self.client.bulk( + index=index_name, + operations=operations, + refresh=False + ) + self._handle_bulk_errors(response) + total_indexed += len(doc_embedding_pairs) + es_batch_elapsed = time.time() - es_batch_start_time + logger.info( + f"[ES BATCH {es_batch_num}/{es_total_batches}] Indexed {len(doc_embedding_pairs)} documents in {es_batch_elapsed:.2f}s. 
Total progress: {total_indexed}/{total_docs}") + + except Exception as e: + logger.error(f"Bulk insert error: {e}, ES batch num: {es_batch_num}") + continue + + # Add 0.1s delay between batches to avoid overloading embedding API + time.sleep(0.1) + + self._force_refresh_with_retry(index_name) + total_elapsed = time.time() - start_time + logger.info( + f"=== [INDEXING COMPLETE] Successfully indexed {total_indexed}/{total_docs} chunks in {total_elapsed:.2f}s (avg: {total_elapsed/es_total_batches:.2f}s/batch) ===") + return total_indexed + except Exception as e: + logger.error(f"Large batch insert failed: {e}") + return 0 + + def _preprocess_documents(self, documents: List[Dict[str, Any]], content_field: str) -> List[Dict[str, Any]]: + """Ensure all documents have the required fields and set default values""" + current_time = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()) + current_date = time.strftime('%Y-%m-%d', time.gmtime()) + + processed_docs = [] + for doc in documents: + # Create a copy of the document to avoid modifying the original data + doc_copy = doc.copy() + + # Set create_time if not present + if not doc_copy.get("create_time"): + doc_copy["create_time"] = current_time + + if not doc_copy.get("date"): + doc_copy["date"] = current_date + + # Ensure file_size is present (default to 0 if not provided) + if not doc_copy.get("file_size"): + logger.warning(f"File size not found in {doc_copy}") + doc_copy["file_size"] = 0 + + # Ensure process_source is present + if not doc_copy.get("process_source"): + doc_copy["process_source"] = "Unstructured" + + # Ensure all documents have an ID + if not doc_copy.get("id"): + doc_copy["id"] = f"{int(time.time())}_{hash(doc_copy[content_field])}"[:20] + + processed_docs.append(doc_copy) + + return processed_docs + + def _handle_bulk_errors(self, response: Dict[str, Any]) -> None: + """Handle bulk operation errors""" + if response.get('errors'): + for item in response['items']: + if 'error' in item.get('index', {}): + error_info = item['index']['error'] + error_type = error_info.get('type') + error_reason = error_info.get('reason') + error_cause = error_info.get('caused_by', {}) + + if error_type == 'version_conflict_engine_exception': + # ignore version conflict + continue + else: + logger.error(f"FATAL ERROR {error_type}: {error_reason}") + if error_cause: + logger.error(f"Caused By: {error_cause.get('type')}: {error_cause.get('reason')}") + + def delete_documents_by_path_or_url(self, index_name: str, path_or_url: str) -> int: + """ + Delete documents based on their path_or_url field + + Args: + index_name: Name of the index to delete documents from + path_or_url: The URL or path of the documents to delete + + Returns: + int: Number of documents deleted + """ + try: + result = self.client.delete_by_query( + index=index_name, + body={ + "query": { + "term": { + "path_or_url": path_or_url + } + } + } + ) + logger.info(f"Successfully deleted {result['deleted']} documents with path_or_url: {path_or_url} from index: {index_name}") + return result['deleted'] + except Exception as e: + logger.error(f"Error deleting documents: {str(e)}") + return 0 + + # ---- SEARCH OPERATIONS ---- + + def accurate_search(self, index_names: List[str], query_text: str, top_k: int = 5) -> List[Dict[str, Any]]: + """ + Search for documents using fuzzy text matching across multiple indices. 
+ + Args: + index_names: Name of the index to search in + query_text: The text query to search for + top_k: Number of results to return + + Returns: + List of search results with scores and document content + """ + # Join index names for multi-index search + index_pattern = ",".join(index_names) + + weights = calculate_term_weights(query_text) + + # Prepare the search query using match query for fuzzy matching + search_query = build_weighted_query(query_text, weights) | { + "size": top_k, + "_source": { + "excludes": ["embedding"] + } + } + + # Execute the search across multiple indices + return self.exec_query(index_pattern, search_query) + + def exec_query(self, index_pattern, search_query): + response = self.client.search( + index=index_pattern, + body=search_query + ) + # Process and return results + results = [] + for hit in response["hits"]["hits"]: + results.append({ + "score": hit["_score"], + "document": hit["_source"], + "index": hit["_index"] # Include source index in results + }) + return results + + def semantic_search(self, index_names: List[str], query_text: str, embedding_model: BaseEmbedding, top_k: int = 5) -> List[Dict[str, Any]]: + """ + Search for similar documents using vector similarity across multiple indices. + + Args: + index_names: List of index names to search in + query_text: The text query to search for + embedding_model: The embedding model to use + top_k: Number of results to return + + Returns: + List of search results with scores and document content + """ + # Join index names for multi-index search + index_pattern = ",".join(index_names) + + # Get query embedding + query_embedding = embedding_model.get_embeddings(query_text)[0] + + # Prepare the search query + search_query = { + "knn": { + "field": "embedding", + "query_vector": query_embedding, + "k": top_k, + "num_candidates": top_k * 2, + }, + "size": top_k, + "_source": { + "excludes": ["embedding"] + } + } + + # Execute the search across multiple indices + return self.exec_query(index_pattern, search_query) + + def hybrid_search( + self, + index_names: List[str], + query_text: str, + embedding_model: BaseEmbedding, + top_k: int = 5, + weight_accurate: float = 0.3 + ) -> List[Dict[str, Any]]: + """ + Hybrid search method, combining accurate matching and semantic search results across multiple indices. 
+ + Args: + index_names: List of index names to search in + query_text: The text query to search for + embedding_model: The embedding model to use + top_k: Number of results to return + weight_accurate: The weight of the accurate matching score (0-1), the semantic search weight is 1-weight_accurate + + Returns: + List of search results sorted by combined score + """ + # Get results from both searches + accurate_results = self.accurate_search(index_names, query_text, top_k=top_k) + semantic_results = self.semantic_search(index_names, query_text, embedding_model=embedding_model, top_k=top_k) + + # Create a mapping from document ID to results + combined_results = {} + + # Process accurate matching results + for result in accurate_results: + try: + doc_id = result['document']['id'] + combined_results[doc_id] = { + 'document': result['document'], + 'accurate_score': result.get('score', 0), + 'semantic_score': 0, + 'index': result['index'] # Keep track of source index + } + except KeyError as e: + logger.warning(f"Warning: Missing required field in accurate result: {e}") + continue + + # Process semantic search results + for result in semantic_results: + try: + doc_id = result['document']['id'] + if doc_id in combined_results: + combined_results[doc_id]['semantic_score'] = result.get('score', 0) + else: + combined_results[doc_id] = { + 'document': result['document'], + 'accurate_score': 0, + 'semantic_score': result.get('score', 0), + 'index': result['index'] # Keep track of source index + } + except KeyError as e: + logger.warning(f"Warning: Missing required field in semantic result: {e}") + continue + + # Calculate maximum scores + max_accurate = max([r.get('score', 0) for r in accurate_results]) if accurate_results else 1 + max_semantic = max([r.get('score', 0) for r in semantic_results]) if semantic_results else 1 + + # Calculate combined scores and sort + results = [] + for doc_id, result in combined_results.items(): + try: + # Get scores safely + accurate_score = result.get('accurate_score', 0) + semantic_score = result.get('semantic_score', 0) + + # Normalize scores + normalized_accurate = accurate_score / max_accurate if max_accurate > 0 else 0 + normalized_semantic = semantic_score / max_semantic if max_semantic > 0 else 0 + + # Calculate weighted combined score + combined_score = (weight_accurate * normalized_accurate + + (1 - weight_accurate) * normalized_semantic) + + results.append({ + 'score': combined_score, + 'document': result['document'], + 'index': result['index'], # Include source index in results + 'scores': { + 'accurate': normalized_accurate, + 'semantic': normalized_semantic + } + }) + except KeyError as e: + logger.warning(f"Warning: Error processing result for doc_id {doc_id}: {e}") + continue + + # Sort by combined score and return top k results + results.sort(key=lambda x: x['score'], reverse=True) + return results[:top_k] + + # ---- STATISTICS AND MONITORING ---- + def get_file_list_with_details(self, index_name: str) -> List[Dict[str, Any]]: + """ + Get a list of unique path_or_url values with their file_size and create_time + + Args: + index_name: Name of the index to query + + Returns: + List of dictionaries with path_or_url, file_size, and create_time + """ + agg_query = { + "size": 0, + "aggs": { + "unique_sources": { + "terms": { + "field": "path_or_url", + "size": 1000 # Limit to 1000 files for performance + }, + "aggs": { + "file_sample": { + "top_hits": { + "size": 1, + "_source": ["path_or_url", "file_size", "create_time", "filename"] + } + } + } + } + } + 
} + + try: + result = self.client.search( + index=index_name, + body=agg_query + ) + + file_list = [] + for bucket in result['aggregations']['unique_sources']['buckets']: + source = bucket['file_sample']['hits']['hits'][0]['_source'] + file_info = { + "path_or_url": source["path_or_url"], + "filename": source.get("filename", ""), + "file_size": source.get("file_size", 0), + "create_time": source.get("create_time", None) + } + file_list.append(file_info) + + return file_list + except Exception as e: + logger.error(f"Error getting file list: {str(e)}") + return [] + + def get_index_mapping(self, index_names: List[str]) -> Dict[str, List[str]]: + """Get field mappings for multiple indices""" + mappings = {} + for index_name in index_names: + try: + mapping = self.client.indices.get_mapping(index=index_name) + if mapping[index_name].get('mappings') and mapping[index_name]['mappings'].get('properties'): + mappings[index_name] = list(mapping[index_name]['mappings']['properties'].keys()) + else: + mappings[index_name] = [] + except Exception as e: + logger.error(f"Error getting mapping for index {index_name}: {str(e)}") + mappings[index_name] = [] + return mappings + + def get_index_stats(self, index_names: List[str], embedding_dim: Optional[int] = None) -> Dict[str, Dict[str, Dict[str, Any]]]: + """Get formatted statistics for multiple indices""" + all_stats = {} + for index_name in index_names: + try: + stats = self.client.indices.stats(index=index_name) + settings = self.client.indices.get_settings(index=index_name) + + # Merge query + agg_query = { + "size": 0, + "aggs": { + "unique_path_or_url_count": { + "cardinality": { + "field": "path_or_url" + } + }, + "process_sources": { + "terms": { + "field": "process_source", + "size": 10 + } + }, + "embedding_models": { + "terms": { + "field": "embedding_model_name", + "size": 10 + } + } + } + } + + # Execute query + agg_result = self.client.search( + index=index_name, + body=agg_query + ) + + unique_sources_count = agg_result['aggregations']['unique_path_or_url_count']['value'] + process_source = agg_result['aggregations']['process_sources']['buckets'][0]['key'] if agg_result['aggregations']['process_sources']['buckets'] else "" + embedding_model = agg_result['aggregations']['embedding_models']['buckets'][0]['key'] if agg_result['aggregations']['embedding_models']['buckets'] else "" + + index_stats = stats["indices"][index_name]["primaries"] + + # Get creation and update timestamps from settings + creation_date = int(settings[index_name]['settings']['index']['creation_date']) + # Update time defaults to creation time if not modified + update_time = creation_date + + all_stats[index_name] = { + "base_info": { + "doc_count": unique_sources_count, + "chunk_count": index_stats["docs"]["count"], + "store_size": format_size(index_stats["store"]["size_in_bytes"]), + "process_source": process_source, + "embedding_model": embedding_model, + "embedding_dim": embedding_dim or 1024, + "creation_date": creation_date, + "update_date": update_time + }, + "search_performance": { + "total_search_count": index_stats["search"]["query_total"], + "hit_count": index_stats["request_cache"]["hit_count"], + } + } + except Exception as e: + logger.error(f"Error getting stats for index {index_name}: {str(e)}") + all_stats[index_name] = {"error": str(e)} + + return all_stats \ No newline at end of file diff --git a/sdk/pyproject.toml b/sdk/pyproject.toml index cf28471f..453857a1 100644 --- a/sdk/pyproject.toml +++ b/sdk/pyproject.toml @@ -69,12 +69,13 @@ data_process = [ ] 
performance = [ # OpenTelemetry Core Components - "opentelemetry-api", - "opentelemetry-sdk", + "opentelemetry-api==1.20.0", + "opentelemetry-sdk==1.20.0", + "opentelemetry-semantic-conventions==0.41b0", # OpenTelemetry Instrumentation - "opentelemetry-instrumentation", - "opentelemetry-instrumentation-fastapi", - "opentelemetry-instrumentation-requests", + "opentelemetry-instrumentation==0.41b0", + "opentelemetry-instrumentation-fastapi==0.41b0", + "opentelemetry-instrumentation-requests==0.41b0", # OpenTelemetry Exporters "opentelemetry-exporter-jaeger", "opentelemetry-exporter-prometheus", @@ -102,4 +103,4 @@ lint.select = ["E", "F", "I", "W"] [tool.ruff.lint.isort] known-first-party = ["nexent"] -lines-after-imports = 2 \ No newline at end of file +lines-after-imports = 2 diff --git a/test/backend/data_process/test_tasks.py b/test/backend/data_process/test_tasks.py index 0abc0f91..c2c18cad 100644 --- a/test/backend/data_process/test_tasks.py +++ b/test/backend/data_process/test_tasks.py @@ -210,6 +210,10 @@ def test_process_local_happy_path(monkeypatch, tmp_path): f = tmp_path / "a.txt" f.write_text("content") + # Mock chunks returned by Ray processing + mock_chunks = [{"content": "chunk1", "metadata": {}}, + {"content": "chunk2", "metadata": {}}] + class FakeActor: class P: def __init__(self, *a, **k): @@ -220,26 +224,84 @@ def __init__(self): self.store_chunks_in_redis = types.SimpleNamespace(remote=lambda *a, **k: None) monkeypatch.setattr(tasks, "get_ray_actor", lambda: FakeActor()) + # Mock ray.get to return chunks instead of reference + fake_ray.get_returns = mock_chunks + self = FakeSelf("p1") result = tasks.process(self, source=str(f), source_type="local", chunking_strategy="basic", index_name="idx", original_filename="a.txt") assert result["redis_key"].startswith("dp:p1:chunks") # success state updated twice: STARTED and SUCCESS assert any(s.get("state") == tasks.states.SUCCESS for s in self.states) + # Verify chunks_count is set correctly (not None) + success_state = [s for s in self.states if s.get( + "state") == tasks.states.SUCCESS][0] + assert success_state.get("meta", {}).get("chunks_count") == 2 def test_process_minio_path(monkeypatch): tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch, initialized=True) + # Mock chunks returned by Ray processing + mock_chunks = [{"content": "minio chunk", "metadata": {}}] + class FakeActor: def __init__(self): self.process_file = types.SimpleNamespace(remote=lambda *a, **k: "ref") self.store_chunks_in_redis = types.SimpleNamespace(remote=lambda *a, **k: None) monkeypatch.setattr(tasks, "get_ray_actor", lambda: FakeActor()) + # Mock ray.get to return chunks + fake_ray.get_returns = mock_chunks + self = FakeSelf("m1") result = tasks.process(self, source="http://minio/bucket/x", source_type="minio", chunking_strategy="basic") assert result["redis_key"].startswith("dp:m1:chunks") + # Verify chunks_count is set + success_state = [s for s in self.states if s.get( + "state") == tasks.states.SUCCESS][0] + assert success_state.get("meta", {}).get("chunks_count") == 1 + + +def test_process_large_file_with_many_chunks(monkeypatch, tmp_path): + """Test processing a large file that generates 100+ chunks""" + tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch, initialized=True) + + # Prepare a fake large file + f = tmp_path / "large.pdf" + f.write_text("large content" * 1000) + + # Mock 150 chunks to simulate large file processing + mock_chunks = [{"content": f"chunk_{i}", "metadata": {}} + for i in range(150)] + + class FakeActor: + 
def __init__(self): + self.process_file = types.SimpleNamespace( + remote=lambda *a, **k: "ref_large") + self.store_chunks_in_redis = types.SimpleNamespace( + remote=lambda *a, **k: None) + + monkeypatch.setattr(tasks, "get_ray_actor", lambda: FakeActor()) + # Mock ray.get to return large chunks + fake_ray.get_returns = mock_chunks + + self = FakeSelf("large1") + + result = tasks.process(self, source=str(f), source_type="local", + chunking_strategy="basic", index_name="idx", original_filename="large.pdf") + + # Verify redis_key is set + assert result["redis_key"].startswith("dp:large1:chunks") + + # Verify chunks_count shows 150 chunks + success_state = [s for s in self.states if s.get( + "state") == tasks.states.SUCCESS][0] + assert success_state.get("meta", {}).get("chunks_count") == 150 + + # Verify processing_time is set + assert "processing_time" in success_state.get("meta", {}) + assert success_state.get("meta", {}).get("processing_time") >= 0 def test_process_raises_on_missing_file(monkeypatch): @@ -724,3 +786,101 @@ def test_forward_empty_chunks_list_warns_and_raises(monkeypatch): tasks.forward(self, processed_data={ "chunks": []}, index_name="idx", source="/a.txt") json.loads(str(ei.value)) + + +def test_process_zero_file_size_speed_calculation(monkeypatch, tmp_path): + """Test that processing_speed_mb_s handles zero file size correctly""" + tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch, initialized=True) + + # Prepare an empty file + f = tmp_path / "empty.txt" + f.write_text("") + + mock_chunks = [{"content": "chunk", "metadata": {}}] + + class FakeActor: + def __init__(self): + self.process_file = types.SimpleNamespace( + remote=lambda *a, **k: "ref") + self.store_chunks_in_redis = types.SimpleNamespace( + remote=lambda *a, **k: None) + + monkeypatch.setattr(tasks, "get_ray_actor", lambda: FakeActor()) + fake_ray.get_returns = mock_chunks + + self = FakeSelf("empty1") + + tasks.process(self, source=str(f), source_type="local", + chunking_strategy="basic", index_name="idx", original_filename="empty.txt") + + # Verify processing_speed_mb_s is 0 for zero-size file (not division by zero) + success_state = [s for s in self.states if s.get( + "state") == tasks.states.SUCCESS][0] + assert success_state.get("meta", {}).get("processing_speed_mb_s") == 0 + + +def test_process_url_source_with_many_chunks(monkeypatch): + """Test processing URL source that generates many chunks""" + tasks, fake_ray = import_tasks_with_fake_ray(monkeypatch, initialized=True) + + # Mock 120 chunks to simulate URL processing + mock_chunks = [{"content": f"url_chunk_{i}", "metadata": {}} + for i in range(120)] + + class FakeActor: + def __init__(self): + self.process_file = types.SimpleNamespace( + remote=lambda *a, **k: "ref_url") + self.store_chunks_in_redis = types.SimpleNamespace( + remote=lambda *a, **k: None) + + monkeypatch.setattr(tasks, "get_ray_actor", lambda: FakeActor()) + fake_ray.get_returns = mock_chunks + + self = FakeSelf("url1") + + result = tasks.process(self, source="http://example.com/doc.pdf", + source_type="minio", chunking_strategy="basic", index_name="idx") + + # Verify chunks_count for URL source + success_state = [s for s in self.states if s.get( + "state") == tasks.states.SUCCESS][0] + assert success_state.get("meta", {}).get("chunks_count") == 120 + assert result["redis_key"].startswith("dp:url1:chunks") + + +def test_forward_large_chunks_batch_success(monkeypatch): + """Test forwarding large batch of chunks (100+) to Elasticsearch""" + tasks, _ = 
import_tasks_with_fake_ray(monkeypatch) + monkeypatch.setattr(tasks, "ELASTICSEARCH_SERVICE", "http://api") + monkeypatch.setattr(tasks, "get_file_size", lambda *a, **k: 5000) + + # Simulate 150 chunks (large file scenario) + large_chunks = [{"content": f"content_{i}", + "metadata": {"page": i}} for i in range(150)] + + # Mock successful indexing of all chunks + monkeypatch.setattr(tasks, "run_async", lambda coro: { + "success": True, + "total_indexed": 150, + "total_submitted": 150, + "message": "All chunks indexed" + }) + + self = FakeSelf("large_forward") + result = tasks.forward( + self, + processed_data={"chunks": large_chunks}, + index_name="idx", + source="/large.pdf", + source_type="local", + original_filename="large.pdf" + ) + + # Verify all 150 chunks were stored + assert result["chunks_stored"] == 150 + + # Verify SUCCESS state was updated + success_state = [s for s in self.states if s.get( + "state") == tasks.states.SUCCESS][0] + assert success_state.get("meta", {}).get("chunks_stored") == 150 diff --git a/test/backend/database/test_agent_db.py b/test/backend/database/test_agent_db.py index 7db21abe..0cd077f0 100644 --- a/test/backend/database/test_agent_db.py +++ b/test/backend/database/test_agent_db.py @@ -82,6 +82,8 @@ def __init__(self): self.delete_flag = "N" self.enabled = True self.updated_by = None + self.business_logic_model_id = None + self.business_logic_model_name = None class MockAgentRelation: def __init__(self): diff --git a/test/backend/services/test_agent_service.py b/test/backend/services/test_agent_service.py index d3bb0e24..7c42b966 100644 --- a/test/backend/services/test_agent_service.py +++ b/test/backend/services/test_agent_service.py @@ -259,7 +259,8 @@ async def test_get_agent_info_impl_success(mock_search_agent_info, mock_search_t "business_description": "Test agent", "tools": mock_tools, "sub_agent_id_list": mock_sub_agent_ids, - "model_name": None + "model_name": None, + "business_logic_model_name": None } assert result == expected_result mock_search_agent_info.assert_called_once_with(123, "test_tenant") @@ -784,7 +785,8 @@ async def test_get_agent_info_impl_with_model_id_success(mock_search_agent_info, "business_description": "Test agent", "tools": mock_tools, "sub_agent_id_list": mock_sub_agent_ids, - "model_name": "GPT-4" + "model_name": "GPT-4", + "business_logic_model_name": None } assert result == expected_result mock_get_model_by_model_id.assert_called_once_with(456) @@ -835,7 +837,8 @@ async def test_get_agent_info_impl_with_model_id_no_display_name(mock_search_age "business_description": "Test agent", "tools": mock_tools, "sub_agent_id_list": mock_sub_agent_ids, - "model_name": None + "model_name": None, + "business_logic_model_name": None } assert result == expected_result mock_get_model_by_model_id.assert_called_once_with(456) @@ -881,12 +884,229 @@ async def test_get_agent_info_impl_with_model_id_none_model_info(mock_search_age "business_description": "Test agent", "tools": mock_tools, "sub_agent_id_list": mock_sub_agent_ids, - "model_name": None + "model_name": None, + "business_logic_model_name": None } assert result == expected_result mock_get_model_by_model_id.assert_called_once_with(456) +@patch('backend.services.agent_service.get_model_by_model_id') +@patch('backend.services.agent_service.query_sub_agents_id_list') +@patch('backend.services.agent_service.search_tools_for_sub_agent') +@patch('backend.services.agent_service.search_agent_info_by_agent_id') +@pytest.mark.asyncio +async def 
test_get_agent_info_impl_with_business_logic_model(mock_search_agent_info, mock_search_tools, mock_query_sub_agents_id, mock_get_model_by_model_id): + """ + Test get_agent_info_impl with business_logic_model_id. + + This test verifies that: + 1. The function correctly retrieves business logic model information when business_logic_model_id is not None + 2. It sets business_logic_model_name from the model's display_name + 3. It handles both main model and business logic model correctly + """ + # Setup + mock_agent_info = { + "agent_id": 123, + "model_id": 456, + "business_logic_model_id": 789, + "business_description": "Test agent" + } + mock_search_agent_info.return_value = mock_agent_info + + mock_tools = [{"tool_id": 1, "name": "Tool 1"}] + mock_search_tools.return_value = mock_tools + + mock_sub_agent_ids = [101, 102] + mock_query_sub_agents_id.return_value = mock_sub_agent_ids + + # Mock model info for main model + mock_main_model_info = { + "model_id": 456, + "display_name": "GPT-4", + "provider": "openai" + } + + # Mock model info for business logic model + mock_business_logic_model_info = { + "model_id": 789, + "display_name": "Claude-3.5", + "provider": "anthropic" + } + + # Mock get_model_by_model_id to return different values based on input + def mock_get_model(model_id): + if model_id == 456: + return mock_main_model_info + elif model_id == 789: + return mock_business_logic_model_info + return None + + mock_get_model_by_model_id.side_effect = mock_get_model + + # Execute + result = await get_agent_info_impl(agent_id=123, tenant_id="test_tenant") + + # Assert + expected_result = { + "agent_id": 123, + "model_id": 456, + "business_logic_model_id": 789, + "business_description": "Test agent", + "tools": mock_tools, + "sub_agent_id_list": mock_sub_agent_ids, + "model_name": "GPT-4", + "business_logic_model_name": "Claude-3.5" + } + assert result == expected_result + + # Verify both models were looked up + assert mock_get_model_by_model_id.call_count == 2 + mock_get_model_by_model_id.assert_any_call(456) + mock_get_model_by_model_id.assert_any_call(789) + + +@patch('backend.services.agent_service.get_model_by_model_id') +@patch('backend.services.agent_service.query_sub_agents_id_list') +@patch('backend.services.agent_service.search_tools_for_sub_agent') +@patch('backend.services.agent_service.search_agent_info_by_agent_id') +@pytest.mark.asyncio +async def test_get_agent_info_impl_with_business_logic_model_none(mock_search_agent_info, mock_search_tools, mock_query_sub_agents_id, mock_get_model_by_model_id): + """ + Test get_agent_info_impl with business_logic_model_id but get_model_by_model_id returns None. + + This test verifies that: + 1. The function correctly handles when business_logic_model_id is not None but get_model_by_model_id returns None + 2. 
It sets business_logic_model_name to None when model_info is None + """ + # Setup + mock_agent_info = { + "agent_id": 123, + "model_id": 456, + "business_logic_model_id": 789, + "business_description": "Test agent" + } + mock_search_agent_info.return_value = mock_agent_info + + mock_tools = [{"tool_id": 1, "name": "Tool 1"}] + mock_search_tools.return_value = mock_tools + + mock_sub_agent_ids = [101, 102] + mock_query_sub_agents_id.return_value = mock_sub_agent_ids + + # Mock model info for main model + mock_main_model_info = { + "model_id": 456, + "display_name": "GPT-4", + "provider": "openai" + } + + # Mock get_model_by_model_id to return None for business_logic_model_id + def mock_get_model(model_id): + if model_id == 456: + return mock_main_model_info + elif model_id == 789: + return None # Business logic model not found + return None + + mock_get_model_by_model_id.side_effect = mock_get_model + + # Execute + result = await get_agent_info_impl(agent_id=123, tenant_id="test_tenant") + + # Assert + expected_result = { + "agent_id": 123, + "model_id": 456, + "business_logic_model_id": 789, + "business_description": "Test agent", + "tools": mock_tools, + "sub_agent_id_list": mock_sub_agent_ids, + "model_name": "GPT-4", + "business_logic_model_name": None # Should be None when model info is not found + } + assert result == expected_result + + # Verify both models were looked up + assert mock_get_model_by_model_id.call_count == 2 + mock_get_model_by_model_id.assert_any_call(456) + mock_get_model_by_model_id.assert_any_call(789) + + +@patch('backend.services.agent_service.get_model_by_model_id') +@patch('backend.services.agent_service.query_sub_agents_id_list') +@patch('backend.services.agent_service.search_tools_for_sub_agent') +@patch('backend.services.agent_service.search_agent_info_by_agent_id') +@pytest.mark.asyncio +async def test_get_agent_info_impl_with_business_logic_model_no_display_name(mock_search_agent_info, mock_search_tools, mock_query_sub_agents_id, mock_get_model_by_model_id): + """ + Test get_agent_info_impl with business_logic_model_id but model has no display_name. + + This test verifies that: + 1. The function correctly retrieves business logic model information when business_logic_model_id is not None + 2. 
It sets business_logic_model_name to None when model_info exists but has no display_name + """ + # Setup + mock_agent_info = { + "agent_id": 123, + "model_id": 456, + "business_logic_model_id": 789, + "business_description": "Test agent" + } + mock_search_agent_info.return_value = mock_agent_info + + mock_tools = [{"tool_id": 1, "name": "Tool 1"}] + mock_search_tools.return_value = mock_tools + + mock_sub_agent_ids = [101, 102] + mock_query_sub_agents_id.return_value = mock_sub_agent_ids + + # Mock model info for main model + mock_main_model_info = { + "model_id": 456, + "display_name": "GPT-4", + "provider": "openai" + } + + # Mock model info for business logic model without display_name + mock_business_logic_model_info = { + "model_id": 789, + "provider": "anthropic" + # No display_name field + } + + # Mock get_model_by_model_id to return different values based on input + def mock_get_model(model_id): + if model_id == 456: + return mock_main_model_info + elif model_id == 789: + return mock_business_logic_model_info + return None + + mock_get_model_by_model_id.side_effect = mock_get_model + + # Execute + result = await get_agent_info_impl(agent_id=123, tenant_id="test_tenant") + + # Assert + expected_result = { + "agent_id": 123, + "model_id": 456, + "business_logic_model_id": 789, + "business_description": "Test agent", + "tools": mock_tools, + "sub_agent_id_list": mock_sub_agent_ids, + "model_name": "GPT-4", + "business_logic_model_name": None # Should be None when display_name is not in model_info + } + assert result == expected_result + + # Verify both models were looked up + assert mock_get_model_by_model_id.call_count == 2 + mock_get_model_by_model_id.assert_any_call(456) + mock_get_model_by_model_id.assert_any_call(789) + + async def test_list_all_agent_info_impl_success(): """ Test successful retrieval of all agent information. 
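The three business-logic-model tests above all exercise the same lookup path in get_agent_info_impl: resolve business_logic_model_id via get_model_by_model_id, then take display_name if present. A minimal sketch of that path, assuming get_model_by_model_id returns a dict or None; the helper name below is hypothetical and is an illustration only, not part of the patch:

    def resolve_business_logic_model_name(agent_info, get_model_by_model_id):
        # No business logic model configured -> nothing to resolve.
        model_id = agent_info.get("business_logic_model_id")
        if model_id is None:
            return None
        model_info = get_model_by_model_id(model_id)
        if not model_info:
            # Model record not found -> name stays None (second test above).
            return None
        # A record without display_name also yields None (third test above).
        return model_info.get("display_name")

Under these assumptions the first test resolves to "Claude-3.5", while the other two resolve to None, matching the expected_result dictionaries asserted above.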
diff --git a/test/backend/services/test_elasticsearch_service.py b/test/backend/services/test_elasticsearch_service.py index 92ac6bbc..46710344 100644 --- a/test/backend/services/test_elasticsearch_service.py +++ b/test/backend/services/test_elasticsearch_service.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock # Mock MinioClient before importing modules that use it from unittest.mock import patch +import numpy as np from fastapi.responses import StreamingResponse @@ -1193,13 +1194,25 @@ def test_summary_index_name(self, mock_get_model_by_model_id, mock_calculate_wei 'model_repo': 'test-repo' } - # Mock get_random_documents - with patch.object(ElasticSearchService, 'get_random_documents') as mock_get_docs: - mock_get_docs.return_value = { - "documents": [ - {"title": "Doc1", "filename": "file1.txt", "content": "Content1"}, - {"title": "Doc2", "filename": "file2.txt", "content": "Content2"} - ] + # Mock the new Map-Reduce functions + with patch('utils.document_vector_utils.process_documents_for_clustering') as mock_process_docs, \ + patch('utils.document_vector_utils.kmeans_cluster_documents') as mock_cluster, \ + patch('utils.document_vector_utils.summarize_clusters_map_reduce') as mock_summarize, \ + patch('utils.document_vector_utils.merge_cluster_summaries') as mock_merge, \ + patch('database.model_management_db.get_model_by_model_id') as mock_get_model_internal: + + # Mock return values + mock_process_docs.return_value = ( + {"doc1": {"chunks": [{"content": "test content"}]}}, # document_samples + {"doc1": np.array([0.1, 0.2, 0.3])} # doc_embeddings + ) + mock_cluster.return_value = {"doc1": 0} # clusters + mock_summarize.return_value = {0: "Test cluster summary"} # cluster_summaries + mock_merge.return_value = "Final merged summary" # final_summary + mock_get_model_internal.return_value = { + 'api_key': 'test_api_key', + 'base_url': 'https://api.test.com', + 'model_name': 'test-model' } # Execute @@ -1228,9 +1241,8 @@ async def run_test(): # Assert self.assertIsInstance(result, StreamingResponse) - mock_get_docs.assert_called_once() - mock_calculate_weights.assert_called_once() - mock_get_model_by_model_id.assert_called_once_with(1, "test_tenant") + # Basic functionality test - just verify the response is correct type + # The detailed function calls are tested in their own unit tests def test_get_random_documents(self): """ @@ -1763,143 +1775,226 @@ def test_check_kb_exist_orphan_in_pg_delete_failure(self, mock_get_knowledge, mo self.assertEqual(result["status"], "error_cleaning_orphans") self.assertTrue(result.get("error")) + # Note: generate_knowledge_summary_stream function has been removed + # These tests are no longer relevant as the function was replaced with summary_index_name + + def test_get_es_core(self): + """ + Test get_es_core function returns the elastic_core instance. + + This test verifies that: + 1. The get_es_core function returns the correct elastic_core instance + 2. 
The function is properly imported and accessible + """ + from backend.services.elasticsearch_service import get_es_core + + # Execute + result = get_es_core() + + # Assert + self.assertIsNotNone(result) + # The result should be the elastic_core instance + self.assertTrue(hasattr(result, 'client')) + @patch('backend.services.elasticsearch_service.tenant_config_manager') - @patch('database.model_management_db.get_model_by_model_id') - def test_generate_knowledge_summary_stream_model_not_found_fallback(self, mock_get_model_by_model_id, mock_tenant_config_manager): + def test_get_embedding_model_embedding_type(self, mock_tenant_config_manager): + """ + Test get_embedding_model with embedding model type. + + This test verifies that: + 1. When model_type is "embedding", OpenAICompatibleEmbedding is returned + 2. The correct parameters are passed to the embedding model """ - Test generate_knowledge_summary_stream when model_id is provided but model_info is None. - Should fallback to default model configuration. + # Setup + mock_config = { + "model_type": "embedding", + "api_key": "test_api_key", + "base_url": "https://test.api.com", + "model_name": "test-model", + "max_tokens": 1024 + } + mock_tenant_config_manager.get_model_config.return_value = mock_config + + # Stop the mock from setUp to test the real function + self.get_embedding_model_patcher.stop() + + try: + with patch('backend.services.elasticsearch_service.OpenAICompatibleEmbedding') as mock_embedding_class: + mock_embedding_instance = MagicMock() + mock_embedding_class.return_value = mock_embedding_instance + + with patch('backend.services.elasticsearch_service.get_model_name_from_config') as mock_get_model_name: + mock_get_model_name.return_value = "test-model" + + # Execute - now we can call the real function + from backend.services.elasticsearch_service import get_embedding_model + result = get_embedding_model("test_tenant") + + # Assert + self.assertEqual(result, mock_embedding_instance) + mock_tenant_config_manager.get_model_config.assert_called_once_with( + key="EMBEDDING_ID", tenant_id="test_tenant") + mock_embedding_class.assert_called_once_with( + api_key="test_api_key", + base_url="https://test.api.com", + model_name="test-model", + embedding_dim=1024 + ) + finally: + # Restart the mock for other tests + self.get_embedding_model_patcher.start() + + @patch('backend.services.elasticsearch_service.tenant_config_manager') + def test_get_embedding_model_multi_embedding_type(self, mock_tenant_config_manager): + """ + Test get_embedding_model with multi_embedding model type. + + This test verifies that: + 1. When model_type is "multi_embedding", JinaEmbedding is returned + 2. 
The correct parameters are passed to the embedding model """ # Setup - mock_get_model_by_model_id.return_value = None # Model not found - mock_tenant_config_manager.get_model_config.return_value = { - 'api_key': 'default_api_key', - 'base_url': 'https://default.api.com', - 'model_name': 'default-model' + mock_config = { + "model_type": "multi_embedding", + "api_key": "test_api_key", + "base_url": "https://test.api.com", + "model_name": "test-model", + "max_tokens": 2048 } - - # Mock OpenAI client - with patch('backend.services.elasticsearch_service.OpenAI') as mock_openai: - mock_client = MagicMock() - mock_openai.return_value = mock_client - - # Mock stream response - mock_response = MagicMock() - mock_response.__iter__ = MagicMock(return_value=iter([ - MagicMock(choices=[MagicMock(delta=MagicMock(content="Test"))]), - MagicMock(choices=[MagicMock(delta=MagicMock(content="END"))]) - ])) - mock_client.chat.completions.create.return_value = mock_response - - # Execute - from backend.services.elasticsearch_service import generate_knowledge_summary_stream - result = list(generate_knowledge_summary_stream( - keywords="test keywords", - language="en", - tenant_id="test_tenant", - model_id=999 # Non-existent model ID - )) - + mock_tenant_config_manager.get_model_config.return_value = mock_config + + # Stop the mock from setUp to test the real function + self.get_embedding_model_patcher.stop() + + try: + with patch('backend.services.elasticsearch_service.JinaEmbedding') as mock_embedding_class: + mock_embedding_instance = MagicMock() + mock_embedding_class.return_value = mock_embedding_instance + + with patch('backend.services.elasticsearch_service.get_model_name_from_config') as mock_get_model_name: + mock_get_model_name.return_value = "test-model" + + # Execute - now we can call the real function + from backend.services.elasticsearch_service import get_embedding_model + result = get_embedding_model("test_tenant") + + # Assert + self.assertEqual(result, mock_embedding_instance) + mock_tenant_config_manager.get_model_config.assert_called_once_with( + key="EMBEDDING_ID", tenant_id="test_tenant") + mock_embedding_class.assert_called_once_with( + api_key="test_api_key", + base_url="https://test.api.com", + model_name="test-model", + embedding_dim=2048 + ) + finally: + # Restart the mock for other tests + self.get_embedding_model_patcher.start() + + @patch('backend.services.elasticsearch_service.tenant_config_manager') + def test_get_embedding_model_unknown_type(self, mock_tenant_config_manager): + """ + Test get_embedding_model with unknown model type. + + This test verifies that: + 1. When model_type is neither "embedding" nor "multi_embedding", None is returned + 2. 
The function handles unknown model types gracefully + """ + # Setup + mock_config = { + "model_type": "unknown_type", + "api_key": "test_api_key", + "base_url": "https://test.api.com", + "model_name": "test-model", + "max_tokens": 1024 + } + mock_tenant_config_manager.get_model_config.return_value = mock_config + + # Stop the mock from setUp to test the real function + self.get_embedding_model_patcher.stop() + + try: + # Execute - now we can call the real function + from backend.services.elasticsearch_service import get_embedding_model + result = get_embedding_model("test_tenant") + # Assert - mock_get_model_by_model_id.assert_called_once_with(999, "test_tenant") + self.assertIsNone(result) mock_tenant_config_manager.get_model_config.assert_called_once_with( - key="LLM_ID", tenant_id="test_tenant" - ) - self.assertEqual(len(result), 3) - self.assertEqual(result[0], "Test") - self.assertEqual(result[1], "END") - self.assertEqual(result[2], "END") + key="EMBEDDING_ID", tenant_id="test_tenant") + finally: + # Restart the mock for other tests + self.get_embedding_model_patcher.start() @patch('backend.services.elasticsearch_service.tenant_config_manager') - @patch('database.model_management_db.get_model_by_model_id') - def test_generate_knowledge_summary_stream_model_exception_fallback(self, mock_get_model_by_model_id, mock_tenant_config_manager): + def test_get_embedding_model_empty_type(self, mock_tenant_config_manager): """ - Test generate_knowledge_summary_stream when getting model info raises an exception. - Should fallback to default model configuration. + Test get_embedding_model with empty model type. + + This test verifies that: + 1. When model_type is empty string, None is returned + 2. The function handles empty model types gracefully """ # Setup - mock_get_model_by_model_id.side_effect = Exception("Database connection error") - mock_tenant_config_manager.get_model_config.return_value = { - 'api_key': 'default_api_key', - 'base_url': 'https://default.api.com', - 'model_name': 'default-model' + mock_config = { + "model_type": "", + "api_key": "test_api_key", + "base_url": "https://test.api.com", + "model_name": "test-model", + "max_tokens": 1024 } - - # Mock OpenAI client - with patch('backend.services.elasticsearch_service.OpenAI') as mock_openai: - mock_client = MagicMock() - mock_openai.return_value = mock_client - - # Mock stream response - mock_response = MagicMock() - mock_response.__iter__ = MagicMock(return_value=iter([ - MagicMock(choices=[MagicMock(delta=MagicMock(content="Test"))]), - MagicMock(choices=[MagicMock(delta=MagicMock(content="END"))]) - ])) - mock_client.chat.completions.create.return_value = mock_response - - # Execute - from backend.services.elasticsearch_service import generate_knowledge_summary_stream - result = list(generate_knowledge_summary_stream( - keywords="test keywords", - language="en", - tenant_id="test_tenant", - model_id=1 - )) - + mock_tenant_config_manager.get_model_config.return_value = mock_config + + # Stop the mock from setUp to test the real function + self.get_embedding_model_patcher.stop() + + try: + # Execute - now we can call the real function + from backend.services.elasticsearch_service import get_embedding_model + result = get_embedding_model("test_tenant") + # Assert - mock_get_model_by_model_id.assert_called_once_with(1, "test_tenant") + self.assertIsNone(result) mock_tenant_config_manager.get_model_config.assert_called_once_with( - key="LLM_ID", tenant_id="test_tenant" - ) - self.assertEqual(len(result), 3) - 
self.assertEqual(result[0], "Test") - self.assertEqual(result[1], "END") - self.assertEqual(result[2], "END") + key="EMBEDDING_ID", tenant_id="test_tenant") + finally: + # Restart the mock for other tests + self.get_embedding_model_patcher.start() @patch('backend.services.elasticsearch_service.tenant_config_manager') - def test_generate_knowledge_summary_stream_no_model_id_default_config(self, mock_tenant_config_manager): + def test_get_embedding_model_missing_type(self, mock_tenant_config_manager): """ - Test generate_knowledge_summary_stream when model_id is None. - Should use default model configuration. + Test get_embedding_model with missing model type. + + This test verifies that: + 1. When model_type is missing from config, None is returned + 2. The function handles missing model types gracefully """ # Setup - mock_tenant_config_manager.get_model_config.return_value = { - 'api_key': 'default_api_key', - 'base_url': 'https://default.api.com', - 'model_name': 'default-model' + mock_config = { + "api_key": "test_api_key", + "base_url": "https://test.api.com", + "model_name": "test-model", + "max_tokens": 1024 } - - # Mock OpenAI client - with patch('backend.services.elasticsearch_service.OpenAI') as mock_openai: - mock_client = MagicMock() - mock_openai.return_value = mock_client - - # Mock stream response - mock_response = MagicMock() - mock_response.__iter__ = MagicMock(return_value=iter([ - MagicMock(choices=[MagicMock(delta=MagicMock(content="Test"))]), - MagicMock(choices=[MagicMock(delta=MagicMock(content="END"))]) - ])) - mock_client.chat.completions.create.return_value = mock_response - - # Execute - from backend.services.elasticsearch_service import generate_knowledge_summary_stream - result = list(generate_knowledge_summary_stream( - keywords="test keywords", - language="en", - tenant_id="test_tenant", - model_id=None # No model_id provided - )) - + mock_tenant_config_manager.get_model_config.return_value = mock_config + + # Stop the mock from setUp to test the real function + self.get_embedding_model_patcher.stop() + + try: + # Execute - now we can call the real function + from backend.services.elasticsearch_service import get_embedding_model + result = get_embedding_model("test_tenant") + # Assert + self.assertIsNone(result) mock_tenant_config_manager.get_model_config.assert_called_once_with( - key="LLM_ID", tenant_id="test_tenant" - ) - self.assertEqual(len(result), 3) - self.assertEqual(result[0], "Test") - self.assertEqual(result[1], "END") - self.assertEqual(result[2], "END") + key="EMBEDDING_ID", tenant_id="test_tenant") + finally: + # Restart the mock for other tests + self.get_embedding_model_patcher.start() if __name__ == '__main__': diff --git a/test/backend/test_cluster_summarization.py b/test/backend/test_cluster_summarization.py new file mode 100644 index 00000000..9fd0a3b9 --- /dev/null +++ b/test/backend/test_cluster_summarization.py @@ -0,0 +1,180 @@ +""" +Test module for cluster summarization + +Tests for cluster summarization functionality. 
+""" +import os +import sys +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +# Add backend to path +current_dir = os.path.dirname(os.path.abspath(__file__)) +backend_dir = os.path.abspath(os.path.join(current_dir, "../../backend")) +sys.path.insert(0, backend_dir) + +from backend.utils.document_vector_utils import ( + extract_cluster_content, + summarize_cluster, + summarize_clusters, + merge_cluster_summaries +) + + +class TestClusterSummarization: + """Test cluster summarization functionality""" + + def test_extract_cluster_content_single_doc(self): + """Test extracting content from cluster with single document""" + document_samples = { + 'doc_001': { + 'filename': 'doc1.pdf', + 'chunks': [ + {'content': 'Content chunk 1'}, + {'content': 'Content chunk 2'}, + {'content': 'Content chunk 3'} + ] + } + } + + cluster_doc_ids = ['doc_001'] + content = extract_cluster_content(document_samples, cluster_doc_ids, max_chunks_per_doc=3) + + assert 'doc1.pdf' in content + assert 'Content chunk 1' in content + assert 'Content chunk 2' in content + assert 'Content chunk 3' in content + + def test_extract_cluster_content_multiple_docs(self): + """Test extracting content from cluster with multiple documents""" + document_samples = { + 'doc_001': { + 'filename': 'doc1.pdf', + 'chunks': [ + {'content': 'Content chunk 1'}, + {'content': 'Content chunk 2'} + ] + }, + 'doc_002': { + 'filename': 'doc2.pdf', + 'chunks': [ + {'content': 'Content chunk 3'}, + {'content': 'Content chunk 4'} + ] + } + } + + cluster_doc_ids = ['doc_001', 'doc_002'] + content = extract_cluster_content(document_samples, cluster_doc_ids, max_chunks_per_doc=3) + + assert 'doc1.pdf' in content + assert 'doc2.pdf' in content + assert 'Content chunk 1' in content + assert 'Content chunk 4' in content + + def test_extract_cluster_content_long_chunks(self): + """Test extracting content with long chunks""" + long_content = 'A' * 1000 + document_samples = { + 'doc_001': { + 'filename': 'doc1.pdf', + 'chunks': [ + {'content': long_content} + ] + } + } + + cluster_doc_ids = ['doc_001'] + content = extract_cluster_content(document_samples, cluster_doc_ids, max_chunks_per_doc=3) + + # Content should be truncated + assert len(content) < len(long_content) + 100 + assert '...' 
in content + + def test_extract_cluster_content_many_chunks(self): + """Test extracting representative chunks when document has many chunks""" + chunks = [{'content': f'Chunk {i}'} for i in range(10)] + document_samples = { + 'doc_001': { + 'filename': 'doc1.pdf', + 'chunks': chunks + } + } + + cluster_doc_ids = ['doc_001'] + content = extract_cluster_content(document_samples, cluster_doc_ids, max_chunks_per_doc=3) + + # Should only include representative chunks (first, middle, last) + assert 'Chunk 0' in content + assert 'Chunk 9' in content + # Middle chunk should be around chunk 4 or 5 + assert 'Chunk 4' in content or 'Chunk 5' in content + + def test_summarize_cluster_placeholder(self): + """Test cluster summarization (placeholder implementation)""" + document_summaries = ["Summary 1", "Summary 2"] + summary = summarize_cluster(document_summaries, language="zh", max_words=150) + + assert summary is not None + assert isinstance(summary, str) + assert 'Cluster Summary' in summary or 'Based on' in summary + + def test_merge_cluster_summaries(self): + """Test merging cluster summaries""" + cluster_summaries = { + 0: "Cluster 0 summary", + 1: "Cluster 1 summary", + 2: "Cluster 2 summary" + } + + merged = merge_cluster_summaries(cluster_summaries) + + assert merged is not None + assert isinstance(merged, str) + assert "Cluster 0 summary" in merged + assert "Cluster 1 summary" in merged + assert "Cluster 2 summary" in merged + + def test_merge_cluster_summaries_empty(self): + """Test merging empty cluster summaries""" + cluster_summaries = {} + merged = merge_cluster_summaries(cluster_summaries) + + assert merged == "" + + def test_summarize_clusters(self): + """Test summarizing multiple clusters""" + document_samples = { + 'doc_001': { + 'filename': 'doc1.pdf', + 'chunks': [{'content': 'Content 1'}] + }, + 'doc_002': { + 'filename': 'doc2.pdf', + 'chunks': [{'content': 'Content 2'}] + }, + 'doc_003': { + 'filename': 'doc3.pdf', + 'chunks': [{'content': 'Content 3'}] + } + } + + clusters = { + 0: ['doc_001', 'doc_002'], + 1: ['doc_003'] + } + + summaries = summarize_clusters(document_samples, clusters, language="zh", max_words=150) + + assert len(summaries) == 2 + assert 0 in summaries + assert 1 in summaries + assert summaries[0] is not None + assert summaries[1] is not None + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) + diff --git a/test/backend/test_document_vector_integration.py b/test/backend/test_document_vector_integration.py new file mode 100644 index 00000000..015818d3 --- /dev/null +++ b/test/backend/test_document_vector_integration.py @@ -0,0 +1,107 @@ +""" +Integration test for document vector operations + +This test demonstrates the complete workflow from ES retrieval to clustering. +Note: This requires a running Elasticsearch instance. 
+""" +import os +import sys +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +# Add backend to path +current_dir = os.path.dirname(os.path.abspath(__file__)) +backend_dir = os.path.abspath(os.path.join(current_dir, "../../backend")) +sys.path.insert(0, backend_dir) + +from backend.utils.document_vector_utils import ( + calculate_document_embedding, + auto_determine_k, + kmeans_cluster_documents +) + + +class TestDocumentVectorIntegration: + """Integration tests for document vector operations""" + + def test_complete_workflow(self): + """Test complete workflow: embedding calculation -> clustering""" + # Simulate document chunks with embeddings + chunks_1 = [ + {'embedding': np.random.rand(128).tolist(), 'content': 'Content for doc 1 chunk 1'}, + {'embedding': np.random.rand(128).tolist(), 'content': 'Content for doc 1 chunk 2'}, + {'embedding': np.random.rand(128).tolist(), 'content': 'Content for doc 1 chunk 3'} + ] + + chunks_2 = [ + {'embedding': np.random.rand(128).tolist(), 'content': 'Content for doc 2 chunk 1'}, + {'embedding': np.random.rand(128).tolist(), 'content': 'Content for doc 2 chunk 2'} + ] + + chunks_3 = [ + {'embedding': np.random.rand(128).tolist(), 'content': 'Content for doc 3 chunk 1'}, + {'embedding': np.random.rand(128).tolist(), 'content': 'Content for doc 3 chunk 2'}, + {'embedding': np.random.rand(128).tolist(), 'content': 'Content for doc 3 chunk 3'}, + {'embedding': np.random.rand(128).tolist(), 'content': 'Content for doc 3 chunk 4'} + ] + + # Calculate document embeddings + doc_embedding_1 = calculate_document_embedding(chunks_1, use_weighted=True) + doc_embedding_2 = calculate_document_embedding(chunks_2, use_weighted=True) + doc_embedding_3 = calculate_document_embedding(chunks_3, use_weighted=True) + + assert doc_embedding_1 is not None + assert doc_embedding_2 is not None + assert doc_embedding_3 is not None + + # Create document embeddings dictionary + doc_embeddings = { + 'doc_001': doc_embedding_1, + 'doc_002': doc_embedding_2, + 'doc_003': doc_embedding_3 + } + + # Determine optimal K + embeddings_array = np.array([doc_embedding_1, doc_embedding_2, doc_embedding_3]) + optimal_k = auto_determine_k(embeddings_array, min_k=2, max_k=3) + + assert 2 <= optimal_k <= 3 + + # Perform clustering + clusters = kmeans_cluster_documents(doc_embeddings, k=optimal_k) + + assert len(clusters) == optimal_k + assert sum(len(docs) for docs in clusters.values()) == 3 + + def test_large_dataset_clustering(self): + """Test clustering with larger simulated dataset""" + # Create simulated document embeddings + n_docs = 50 + doc_embeddings = { + f'doc_{i:03d}': np.random.rand(128) for i in range(n_docs) + } + + # Auto-determine K + embeddings_array = np.array(list(doc_embeddings.values())) + optimal_k = auto_determine_k(embeddings_array, min_k=3, max_k=15) + + assert 3 <= optimal_k <= 15 + + # Cluster documents + clusters = kmeans_cluster_documents(doc_embeddings, k=optimal_k) + + assert len(clusters) == optimal_k + assert sum(len(docs) for docs in clusters.values()) == n_docs + + # Verify cluster sizes are reasonable + cluster_sizes = [len(docs) for docs in clusters.values()] + assert min(cluster_sizes) >= 1 + # Allow for some imbalance in clustering results (realistic for random data) + assert max(cluster_sizes) <= n_docs * 0.7 # No single cluster dominates too much + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) + diff --git a/test/backend/test_document_vector_utils.py b/test/backend/test_document_vector_utils.py new file 
mode 100644 index 00000000..d4f8ef43 --- /dev/null +++ b/test/backend/test_document_vector_utils.py @@ -0,0 +1,470 @@ +""" +Test module for document_vector_utils + +Tests for document-level vector operations and clustering functionality. +""" +import os +import sys +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +# Add backend to path +current_dir = os.path.dirname(os.path.abspath(__file__)) +backend_dir = os.path.abspath(os.path.join(current_dir, "../../backend")) +sys.path.insert(0, backend_dir) + +from backend.utils.document_vector_utils import ( + calculate_document_embedding, + auto_determine_k, + kmeans_cluster_documents, + extract_representative_chunks_smart, + summarize_document, + summarize_cluster, + summarize_clusters_map_reduce, + merge_cluster_summaries, + get_documents_from_es, + process_documents_for_clustering, + extract_cluster_content, + analyze_cluster_coherence +) + + +class TestDocumentEmbedding: + """Test document embedding calculation""" + + def test_calculate_document_embedding_simple_average(self): + """Test simple average embedding calculation""" + chunks = [ + {'embedding': [1.0, 2.0, 3.0], 'content': 'Content 1'}, + {'embedding': [4.0, 5.0, 6.0], 'content': 'Content 2'}, + {'embedding': [7.0, 8.0, 9.0], 'content': 'Content 3'} + ] + + result = calculate_document_embedding(chunks, use_weighted=False) + + assert result is not None + assert np.allclose(result, [4.0, 5.0, 6.0]) # Average of all embeddings + + def test_calculate_document_embedding_weighted(self): + """Test weighted average embedding calculation""" + chunks = [ + {'embedding': [1.0, 2.0], 'content': 'Short'}, + {'embedding': [3.0, 4.0], 'content': 'Long content with more words'}, + {'embedding': [5.0, 6.0], 'content': 'Medium length content'} + ] + + result = calculate_document_embedding(chunks, use_weighted=True) + + assert result is not None + assert len(result) == 2 + + def test_calculate_document_embedding_empty_chunks(self): + """Test handling of empty chunks""" + chunks = [] + result = calculate_document_embedding(chunks) + assert result is None + + def test_calculate_document_embedding_no_embeddings(self): + """Test handling of chunks without embeddings""" + chunks = [ + {'content': 'Content 1'}, + {'content': 'Content 2'} + ] + result = calculate_document_embedding(chunks) + assert result is None + + +class TestAutoDetermineK: + """Test automatic K determination""" + + def test_auto_determine_k_small_dataset(self): + """Test K determination for small dataset""" + embeddings = np.random.rand(10, 128) + k = auto_determine_k(embeddings, min_k=3, max_k=15) + + assert 3 <= k <= 15 + + def test_auto_determine_k_large_dataset(self): + """Test K determination for large dataset""" + embeddings = np.random.rand(200, 128) + k = auto_determine_k(embeddings, min_k=3, max_k=15) + + assert 3 <= k <= 15 + + def test_auto_determine_k_very_small_dataset(self): + """Test K determination for very small dataset""" + embeddings = np.random.rand(5, 128) + k = auto_determine_k(embeddings, min_k=3, max_k=15) + + assert k >= 2 + assert k <= 5 + + def test_auto_determine_k_minimum(self): + """Test K determination respects minimum""" + embeddings = np.random.rand(100, 128) + k = auto_determine_k(embeddings, min_k=5, max_k=15) + + assert k >= 5 + + +class TestKMeansClustering: + """Test K-means clustering""" + + def test_kmeans_cluster_documents(self): + """Test basic K-means clustering""" + doc_embeddings = { + 'doc1': np.array([1.0, 1.0]), + 'doc2': np.array([1.1, 1.1]), + 'doc3': 
np.array([5.0, 5.0]), + 'doc4': np.array([5.1, 5.1]), + 'doc5': np.array([9.0, 9.0]), + 'doc6': np.array([9.1, 9.1]) + } + + clusters = kmeans_cluster_documents(doc_embeddings, k=3) + + assert len(clusters) == 3 + assert sum(len(docs) for docs in clusters.values()) == 6 + + def test_kmeans_cluster_documents_auto_k(self): + """Test K-means clustering with auto-determined K""" + doc_embeddings = { + f'doc{i}': np.random.rand(128) for i in range(50) + } + + clusters = kmeans_cluster_documents(doc_embeddings, k=None) + + assert len(clusters) > 0 + assert sum(len(docs) for docs in clusters.values()) == 50 + + def test_kmeans_cluster_documents_empty(self): + """Test handling of empty embeddings""" + doc_embeddings = {} + clusters = kmeans_cluster_documents(doc_embeddings) + + assert clusters == {} + + def test_kmeans_cluster_documents_single(self): + """Test handling of single document""" + doc_embeddings = { + 'doc1': np.array([1.0, 1.0, 1.0]) + } + clusters = kmeans_cluster_documents(doc_embeddings) + + # Should return single cluster with one document + assert len(clusters) == 1 + assert 0 in clusters + assert len(clusters[0]) == 1 + assert clusters[0][0] == 'doc1' + + +class TestExtractRepresentativeChunksSmart: + """Test smart chunk selection""" + + def test_extract_representative_chunks_smart_basic(self): + """Test basic smart chunk selection""" + chunks = [ + {'content': 'First chunk content'}, + {'content': 'Second chunk content'}, + {'content': 'Third chunk content'}, + {'content': 'Fourth chunk content'} + ] + + result = extract_representative_chunks_smart(chunks, max_chunks=3) + + assert len(result) <= 3 + assert result[0] == chunks[0] # First chunk always included + assert result[-1] == chunks[-1] # Last chunk included + + def test_extract_representative_chunks_smart_import_error(self): + """Test fallback when calculate_term_weights import fails""" + chunks = [ + {'content': 'First chunk content'}, + {'content': 'Second chunk content'}, + {'content': 'Third chunk content'}, + {'content': 'Fourth chunk content'} + ] + + # Mock the import to fail + with patch.dict('sys.modules', {'nexent.core.nlp.tokenizer': None}): + result = extract_representative_chunks_smart(chunks, max_chunks=3) + + # The fallback logic actually returns 3 chunks (first, middle, last) + assert len(result) == 3 + assert result[0] == chunks[0] # First chunk + assert result[-1] == chunks[-1] # Last chunk + + +class TestSummarizeDocument: + """Test document summarization""" + + def test_summarize_document_no_model(self): + """Test document summarization without model""" + result = summarize_document( + document_content="Test content", + filename="test.pdf", + model_id=None, + tenant_id=None + ) + assert isinstance(result, str) + assert "test.pdf" in result + + def test_summarize_document_with_model_placeholder(self): + """Test document summarization with model ID but no actual LLM call""" + result = summarize_document( + document_content="Test content for summarization", + filename="test.pdf", + model_id=999, # Non-existent model + tenant_id="test_tenant" + ) + assert isinstance(result, str) + assert len(result) > 0 + + +class TestSummarizeCluster: + """Test cluster summarization""" + + def test_summarize_cluster_no_model(self): + """Test cluster summarization without model""" + result = summarize_cluster( + document_summaries=["Summary 1", "Summary 2"], + model_id=None, + tenant_id=None + ) + assert isinstance(result, str) + assert "Summary" in result + + def test_summarize_cluster_with_model_placeholder(self): + 
"""Test cluster summarization with model ID but no actual LLM call""" + result = summarize_cluster( + document_summaries=["Summary 1", "Summary 2"], + model_id=999, # Non-existent model + tenant_id="test_tenant" + ) + assert isinstance(result, str) + assert len(result) > 0 + + +class TestSummarizeClustersMapReduce: + """Test map-reduce cluster summarization""" + + def test_summarize_clusters_map_reduce_basic(self): + """Test basic map-reduce summarization""" + document_samples = { + 'doc1': { + 'chunks': [{'content': 'Content 1'}], + 'filename': 'doc1.pdf', + 'path_or_url': '/path/doc1.pdf' + }, + 'doc2': { + 'chunks': [{'content': 'Content 2'}], + 'filename': 'doc2.pdf', + 'path_or_url': '/path/doc2.pdf' + } + } + clusters = {0: ['doc1', 'doc2']} + + with patch('backend.utils.document_vector_utils.summarize_document') as mock_summarize_doc, \ + patch('backend.utils.document_vector_utils.summarize_cluster') as mock_summarize_cluster: + + mock_summarize_doc.return_value = "Document summary" + mock_summarize_cluster.return_value = "Cluster summary" + + result = summarize_clusters_map_reduce( + document_samples=document_samples, + clusters=clusters, + model_id=1, + tenant_id="test_tenant" + ) + + assert isinstance(result, dict) + assert 0 in result + assert result[0] == "Cluster summary" + + def test_summarize_clusters_map_reduce_no_valid_documents(self): + """Test map-reduce when no valid documents in cluster""" + document_samples = { + 'doc1': { + 'chunks': [], + 'filename': 'doc1.pdf' + } + } + clusters = {0: ['doc1']} + + with patch('backend.utils.document_vector_utils.summarize_document') as mock_summarize_doc, \ + patch('backend.utils.document_vector_utils.summarize_cluster') as mock_summarize_cluster: + + mock_summarize_doc.return_value = "" + mock_summarize_cluster.return_value = "Mock cluster summary" + + result = summarize_clusters_map_reduce( + document_samples=document_samples, + clusters=clusters, + model_id=1, + tenant_id="test_tenant" + ) + + assert isinstance(result, dict) + assert 0 in result + assert result[0] == "Mock cluster summary" + + +class TestMergeClusterSummaries: + """Test cluster summary merging""" + + def test_merge_cluster_summaries(self): + """Test merging multiple cluster summaries""" + cluster_summaries = { + 0: "First cluster summary", + 1: "Second cluster summary", + 2: "Third cluster summary" + } + + result = merge_cluster_summaries(cluster_summaries) + + assert isinstance(result, str) + assert "First cluster summary" in result + assert "Second cluster summary" in result + assert "Third cluster summary" in result + assert "
<p>
    " in result # Should use HTML p tags + + +class TestGetDocumentsFromEs: + """Test ES document retrieval""" + + def test_get_documents_from_es_mock(self): + """Test ES document retrieval with mocked client""" + mock_es_core = MagicMock() + mock_es_core.client.search.return_value = { + 'hits': { + 'hits': [ + { + '_source': { + 'path_or_url': '/path/doc1.pdf', + 'filename': 'doc1.pdf', + 'content': 'Content 1', + 'embedding': [1.0, 2.0, 3.0] + } + } + ] + }, + 'aggregations': { + 'unique_documents': { + 'buckets': [ + { + 'key': '/path/doc1.pdf', + 'doc_count': 1 + } + ] + } + } + } + + result = get_documents_from_es('test_index', mock_es_core, sample_doc_count=10) + + assert isinstance(result, dict) + # The function returns a dict with document IDs as keys, not 'documents' key + assert len(result) > 0 + # Check that we have document data + first_doc = list(result.values())[0] + assert 'chunks' in first_doc + + +class TestProcessDocumentsForClustering: + """Test document processing for clustering""" + + def test_process_documents_for_clustering_mock(self): + """Test document processing with mocked functions""" + mock_es_core = MagicMock() + mock_es_core.client.search.return_value = { + 'hits': { + 'hits': [ + { + '_source': { + 'path_or_url': '/path/doc1.pdf', + 'filename': 'doc1.pdf', + 'content': 'Content 1', + 'embedding': [1.0, 2.0, 3.0] + } + } + ] + }, + 'aggregations': { + 'unique_documents': { + 'buckets': [ + { + 'key': '/path/doc1.pdf', + 'doc_count': 1 + } + ] + } + } + } + + with patch('backend.utils.document_vector_utils.calculate_document_embedding') as mock_calc_embedding: + mock_calc_embedding.return_value = np.array([1.0, 2.0, 3.0]) + + documents, embeddings = process_documents_for_clustering( + 'test_index', mock_es_core, sample_doc_count=10 + ) + + assert isinstance(documents, dict) + assert isinstance(embeddings, dict) + assert len(documents) == len(embeddings) + + +class TestExtractClusterContent: + """Test cluster content extraction""" + + def test_extract_cluster_content(self): + """Test extracting content from cluster documents""" + document_samples = { + 'doc1': { + 'chunks': [{'content': 'Content 1'}], + 'filename': 'doc1.pdf' + }, + 'doc2': { + 'chunks': [{'content': 'Content 2'}], + 'filename': 'doc2.pdf' + } + } + doc_ids = ['doc1', 'doc2'] + + result = extract_cluster_content(document_samples, doc_ids) + + assert isinstance(result, str) # The function returns a formatted string + assert 'Content 1' in result + assert 'Content 2' in result + assert 'doc1.pdf' in result + assert 'doc2.pdf' in result + + +class TestAnalyzeClusterCoherence: + """Test cluster coherence analysis""" + + def test_analyze_cluster_coherence(self): + """Test cluster coherence analysis""" + document_samples = { + 'doc1': { + 'filename': 'doc1.pdf', + 'path_or_url': '/path/doc1.pdf' + }, + 'doc2': { + 'filename': 'doc2.pdf', + 'path_or_url': '/path/doc2.pdf' + } + } + doc_ids = ['doc1', 'doc2'] + + result = analyze_cluster_coherence(doc_ids, document_samples) + + assert isinstance(result, dict) + assert 'doc_count' in result + assert result['doc_count'] == 2 + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) + diff --git a/test/backend/test_document_vector_utils_coverage.py b/test/backend/test_document_vector_utils_coverage.py new file mode 100644 index 00000000..290cf5d4 --- /dev/null +++ b/test/backend/test_document_vector_utils_coverage.py @@ -0,0 +1,651 @@ +""" +Supplementary test module for document_vector_utils to improve code coverage + +Tests for functions not fully 
covered in other test files. +""" +import os +import sys +from unittest.mock import MagicMock, patch, mock_open + +import numpy as np +import pytest + +# Add backend to path +current_dir = os.path.dirname(os.path.abspath(__file__)) +backend_dir = os.path.abspath(os.path.join(current_dir, "../../backend")) +sys.path.insert(0, backend_dir) + +from backend.utils.document_vector_utils import ( + get_documents_from_es, + process_documents_for_clustering, + extract_cluster_content, + extract_representative_chunks_smart, + analyze_cluster_coherence, + summarize_document, + summarize_cluster, + summarize_cluster_legacy, + summarize_clusters_map_reduce, + summarize_clusters, + merge_cluster_summaries, + calculate_document_embedding, + auto_determine_k, + kmeans_cluster_documents +) + + +class TestGetDocumentsFromES: + """Test Elasticsearch document retrieval""" + + def test_get_documents_from_es_success(self): + """Test successful document retrieval from ES""" + mock_es_core = MagicMock() + mock_es_core.client.search.return_value = { + 'aggregations': { + 'unique_documents': { + 'buckets': [ + {'key': '/path/doc1.pdf', 'doc_count': 3}, + {'key': '/path/doc2.pdf', 'doc_count': 2} + ] + } + }, + 'hits': { + 'hits': [ + { + '_source': { + 'filename': 'doc1.pdf', + 'content': 'test content', + 'embedding': [0.1, 0.2, 0.3], + 'file_size': 1000 + } + } + ] + } + } + + result = get_documents_from_es('test_index', mock_es_core, sample_doc_count=10) + assert isinstance(result, dict) + assert mock_es_core.client.search.called + + def test_get_documents_from_es_empty(self): + """Test ES retrieval with no documents""" + mock_es_core = MagicMock() + mock_es_core.client.search.return_value = { + 'aggregations': { + 'unique_documents': { + 'buckets': [] + } + } + } + + result = get_documents_from_es('test_index', mock_es_core) + assert result == {} + + def test_get_documents_from_es_error(self): + """Test ES retrieval error handling""" + mock_es_core = MagicMock() + mock_es_core.client.search.side_effect = Exception("ES error") + + with pytest.raises(Exception, match="Failed to retrieve documents from Elasticsearch"): + get_documents_from_es('test_index', mock_es_core) + + +class TestProcessDocumentsForClustering: + """Test document processing for clustering""" + + @patch('backend.utils.document_vector_utils.get_documents_from_es') + @patch('backend.utils.document_vector_utils.calculate_document_embedding') + def test_process_documents_success(self, mock_calc_emb, mock_get_docs): + """Test successful document processing""" + mock_get_docs.return_value = { + 'doc1': { + 'chunks': [{'embedding': [0.1, 0.2, 0.3]}], + 'filename': 'test.pdf' + } + } + mock_calc_emb.return_value = np.array([0.1, 0.2, 0.3]) + + mock_es_core = MagicMock() + docs, embeddings = process_documents_for_clustering('test_index', mock_es_core) + + assert isinstance(docs, dict) + assert isinstance(embeddings, dict) + assert 'doc1' in docs + assert 'doc1' in embeddings + + @patch('backend.utils.document_vector_utils.get_documents_from_es') + def test_process_documents_empty(self, mock_get_docs): + """Test processing with no documents""" + mock_get_docs.return_value = {} + + mock_es_core = MagicMock() + docs, embeddings = process_documents_for_clustering('test_index', mock_es_core) + + assert docs == {} + assert embeddings == {} + + +class TestExtractClusterContent: + """Test cluster content extraction""" + + def test_extract_cluster_content_basic(self): + """Test basic cluster content extraction""" + document_samples = { + 'doc1': { + 'chunks': [ 
+ {'content': 'chunk 1'}, + {'content': 'chunk 2'} + ] + } + } + cluster_doc_ids = ['doc1'] + + result = extract_cluster_content(document_samples, cluster_doc_ids) + assert isinstance(result, str) + assert len(result) > 0 + + def test_extract_representative_chunks_smart(self): + """Test smart chunk extraction""" + chunks = [ + {'content': 'important keyword data'}, + {'content': 'regular content'}, + {'content': 'more keyword information'} + ] + + result = extract_representative_chunks_smart(chunks, max_chunks=2) + assert len(result) <= 2 + assert len(result) > 0 + + def test_extract_representative_chunks_smart_single(self): + """Test smart extraction with single chunk""" + chunks = [ + {'content': 'single chunk content'} + ] + + result = extract_representative_chunks_smart(chunks, max_chunks=1) + assert len(result) == 1 + + +class TestAnalyzeClusterCoherence: + """Test cluster coherence analysis""" + + def test_analyze_cluster_coherence_basic(self): + """Test basic cluster coherence analysis""" + document_samples = { + 'doc1': { + 'filename': 'test1.pdf', + 'chunks': [{'content': 'test content 1'}], + 'file_size': 1000 + }, + 'doc2': { + 'filename': 'test2.pdf', + 'chunks': [{'content': 'test content 2'}], + 'file_size': 2000 + } + } + cluster_doc_ids = ['doc1', 'doc2'] + + result = analyze_cluster_coherence(cluster_doc_ids, document_samples) + assert isinstance(result, dict) + + +class TestSummarizeDocument: + """Test document summarization""" + + def test_summarize_document_no_model(self): + """Test document summarization without model""" + result = summarize_document( + document_content="Test content", + filename="test.pdf", + model_id=None, + tenant_id=None + ) + assert isinstance(result, str) + assert "test.pdf" in result + + def test_summarize_document_with_model_placeholder(self): + """Test document summarization with model ID but no actual LLM call""" + # With model_id and tenant_id, but without actual database connection, + # it should return a placeholder or error message + result = summarize_document( + document_content="Test content for summarization", + filename="test.pdf", + model_id=999, # Non-existent model + tenant_id="test_tenant" + ) + assert isinstance(result, str) + # Either placeholder summary or error handling + assert len(result) > 0 + + +class TestSummarizeCluster: + """Test cluster summarization""" + + def test_summarize_cluster_no_model(self): + """Test cluster summarization without model""" + doc_summaries = ["Summary 1", "Summary 2"] + # Without model, it will return a formatted summary + result = summarize_cluster( + document_summaries=doc_summaries, + model_id=None, + tenant_id=None + ) + assert isinstance(result, str) + # The function returns an error or formatted text, just check it's a string + assert len(result) > 0 + + def test_summarize_cluster_legacy(self): + """Test legacy cluster summarization""" + cluster_content = "Test cluster content" + + result = summarize_cluster_legacy(cluster_content) + assert isinstance(result, str) + + +class TestSummarizeClustersMapReduce: + """Test Map-Reduce cluster summarization""" + + @patch('backend.utils.document_vector_utils.summarize_document') + @patch('backend.utils.document_vector_utils.summarize_cluster') + def test_summarize_clusters_map_reduce(self, mock_sum_cluster, mock_sum_doc): + """Test Map-Reduce summarization""" + document_samples = { + 'doc1': { + 'filename': 'test1.pdf', + 'chunks': [{'content': 'test content 1'}] + }, + 'doc2': { + 'filename': 'test2.pdf', + 'chunks': [{'content': 'test content 
2'}] + } + } + # clusters should map cluster_id to list of doc_ids + clusters = {0: ['doc1', 'doc2']} + + mock_sum_doc.return_value = "Doc summary" + mock_sum_cluster.return_value = "Cluster summary" + + result = summarize_clusters_map_reduce( + document_samples=document_samples, + clusters=clusters, + language='en' + ) + + assert isinstance(result, dict) + assert 0 in result + + +class TestMergeClusterSummaries: + """Test cluster summary merging""" + + def test_merge_cluster_summaries_basic(self): + """Test basic cluster summary merging""" + cluster_summaries = { + 0: "Summary for cluster 0", + 1: "Summary for cluster 1" + } + + result = merge_cluster_summaries(cluster_summaries) + assert isinstance(result, str) + assert "Summary for cluster 0" in result + assert "Summary for cluster 1" in result + assert "
<p>
    " in result # HTML paragraph tags + + def test_merge_cluster_summaries_empty(self): + """Test merging empty summaries""" + cluster_summaries = { + 0: "", + 1: "Summary for cluster 1" + } + + result = merge_cluster_summaries(cluster_summaries) + assert isinstance(result, str) + assert "Summary for cluster 1" in result + + def test_merge_cluster_summaries_single(self): + """Test merging single cluster summary""" + cluster_summaries = { + 0: "Single cluster summary" + } + + result = merge_cluster_summaries(cluster_summaries) + assert isinstance(result, str) + assert "Single cluster summary" in result + + +class TestAdditionalCoverage: + """Test additional coverage for uncovered code paths""" + + def test_get_documents_from_es_non_list_documents(self): + """Test ES retrieval when all_documents is not a list""" + mock_es_core = MagicMock() + + # Mock the first search call to return a tuple instead of list + mock_es_core.client.search.side_effect = [ + { + 'aggregations': { + 'unique_documents': { + 'buckets': ( # This will trigger the isinstance check + {'key': '/path/doc1.pdf', 'doc_count': 3}, + ) + } + } + }, + { + 'hits': { + 'hits': [ + { + '_source': { + 'filename': 'doc1.pdf', + 'content': 'test content', + 'embedding': [0.1, 0.2, 0.3], + 'file_size': 1000 + } + } + ] + } + } + ] + + result = get_documents_from_es('test_index', mock_es_core) + assert isinstance(result, dict) + + def test_get_documents_from_es_no_chunks(self): + """Test ES retrieval when document has no chunks""" + mock_es_core = MagicMock() + mock_es_core.client.search.side_effect = [ + { + 'aggregations': { + 'unique_documents': { + 'buckets': [ + {'key': '/path/doc1.pdf', 'doc_count': 0} + ] + } + } + }, + { + 'hits': { + 'hits': [] # No chunks + } + } + ] + + result = get_documents_from_es('test_index', mock_es_core) + assert result == {} # Should return empty dict when no chunks + + def test_calculate_document_embedding_exception(self): + """Test calculate_document_embedding with exception""" + chunks = [ + {'content': 'test content', 'embedding': [0.1, 0.2, 0.3]} + ] + + # Mock numpy operations to raise exception + with patch('numpy.array') as mock_array: + mock_array.side_effect = Exception("Numpy error") + + result = calculate_document_embedding(chunks) + assert result is None + + def test_auto_determine_k_small_dataset(self): + """Test auto_determine_k with very small dataset""" + # Create embeddings with only 2 samples (less than min_k=3) + embeddings = np.array([[0.1, 0.2], [0.3, 0.4]]) + + result = auto_determine_k(embeddings, min_k=3, max_k=5) + assert result == 2 # Should return max(2, n_samples) + + def test_auto_determine_k_exception(self): + """Test auto_determine_k with exception during calculation""" + embeddings = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) + + # Mock silhouette_score to raise exception + with patch('sklearn.metrics.silhouette_score') as mock_silhouette: + mock_silhouette.side_effect = Exception("Silhouette error") + + result = auto_determine_k(embeddings, min_k=2, max_k=3) + # Should use heuristic fallback + assert isinstance(result, int) + assert result >= 2 + + def test_kmeans_cluster_documents_empty(self): + """Test kmeans_cluster_documents with empty embeddings""" + result = kmeans_cluster_documents({}) + assert result == {} + + def test_kmeans_cluster_documents_exception(self): + """Test kmeans_cluster_documents with exception""" + doc_embeddings = { + 'doc1': np.array([0.1, 0.2, 0.3]), + 'doc2': np.array([0.4, 0.5, 0.6]) + } + + # Mock auto_determine_k to raise exception 
+ with patch('backend.utils.document_vector_utils.auto_determine_k') as mock_auto_k: + mock_auto_k.side_effect = Exception("Auto K error") + + with pytest.raises(Exception, match="Failed to cluster documents"): + kmeans_cluster_documents(doc_embeddings) + + def test_process_documents_for_clustering_exception(self): + """Test process_documents_for_clustering with exception""" + mock_es_core = MagicMock() + mock_es_core.client.search.side_effect = Exception("ES error") + + with pytest.raises(Exception, match="Failed to process documents"): + process_documents_for_clustering('test_index', mock_es_core) + + def test_process_documents_for_clustering_no_embeddings(self): + """Test process_documents_for_clustering when some documents fail embedding calculation""" + mock_es_core = MagicMock() + mock_es_core.client.search.return_value = { + 'aggregations': { + 'unique_documents': { + 'buckets': [ + {'key': '/path/doc1.pdf', 'doc_count': 1} + ] + } + }, + 'hits': { + 'hits': [ + { + '_source': { + 'filename': 'doc1.pdf', + 'content': 'test content', + 'embedding': [0.1, 0.2, 0.3], + 'file_size': 1000 + } + } + ] + } + } + + # Mock calculate_document_embedding to return None + with patch('backend.utils.document_vector_utils.calculate_document_embedding') as mock_calc: + mock_calc.return_value = None + + docs, embeddings = process_documents_for_clustering('test_index', mock_es_core) + assert isinstance(docs, dict) + assert isinstance(embeddings, dict) + assert len(embeddings) == 0 # No successful embeddings + + def test_extract_cluster_content_missing_doc(self): + """Test extract_cluster_content with missing document""" + document_samples = { + 'doc1': { + 'chunks': [{'content': 'test content'}] + } + } + cluster_doc_ids = ['doc1', 'missing_doc'] + + result = extract_cluster_content(document_samples, cluster_doc_ids) + assert isinstance(result, str) + assert 'test content' in result + + def test_extract_cluster_content_no_chunks(self): + """Test extract_cluster_content with document having no chunks""" + document_samples = { + 'doc1': { + 'chunks': [] + } + } + cluster_doc_ids = ['doc1'] + + result = extract_cluster_content(document_samples, cluster_doc_ids) + assert isinstance(result, str) + + def test_extract_representative_chunks_smart_import_error(self): + """Test extract_representative_chunks_smart with ImportError""" + chunks = [ + {'content': 'chunk 1'}, + {'content': 'chunk 2'}, + {'content': 'chunk 3'} + ] + + # Mock the import to raise ImportError + with patch('builtins.__import__', side_effect=ImportError("Module not found")): + result = extract_representative_chunks_smart(chunks, max_chunks=2) + assert len(result) <= 2 + assert len(result) > 0 + + def test_extract_representative_chunks_smart_short_content(self): + """Test extract_representative_chunks_smart with short content""" + chunks = [ + {'content': 'short'}, + {'content': 'also short'}, + {'content': 'very short content'} + ] + + result = extract_representative_chunks_smart(chunks, max_chunks=2) + assert len(result) <= 2 + assert len(result) > 0 + + def test_analyze_cluster_coherence_empty(self): + """Test analyze_cluster_coherence with empty cluster_doc_ids""" + document_samples = { + 'doc1': { + 'chunks': [{'content': 'test content'}] + } + } + cluster_doc_ids = [] + + result = analyze_cluster_coherence(cluster_doc_ids, document_samples) + assert result == {} + + def test_analyze_cluster_coherence_missing_doc(self): + """Test analyze_cluster_coherence with missing document""" + document_samples = { + 'doc1': { + 'chunks': 
[{'content': 'test content'}] + } + } + cluster_doc_ids = ['doc1', 'missing_doc'] + + result = analyze_cluster_coherence(cluster_doc_ids, document_samples) + assert isinstance(result, dict) + + def test_analyze_cluster_coherence_no_chunks(self): + """Test analyze_cluster_coherence with document having no chunks""" + document_samples = { + 'doc1': { + 'chunks': [] + } + } + cluster_doc_ids = ['doc1'] + + result = analyze_cluster_coherence(cluster_doc_ids, document_samples) + assert isinstance(result, dict) + + def test_summarize_clusters_map_reduce_missing_doc(self): + """Test summarize_clusters_map_reduce with missing document""" + document_samples = { + 'doc1': { + 'chunks': [{'content': 'test content'}], + 'filename': 'test.pdf' + } + } + clusters = {0: ['doc1', 'missing_doc']} + + with patch('backend.utils.document_vector_utils.summarize_document') as mock_sum_doc: + mock_sum_doc.return_value = "Doc summary" + + with patch('backend.utils.document_vector_utils.summarize_cluster') as mock_sum_cluster: + mock_sum_cluster.return_value = "Cluster summary" + + result = summarize_clusters_map_reduce(document_samples, clusters) + assert isinstance(result, dict) + assert 0 in result + + def test_summarize_clusters_map_reduce_few_chunks(self): + """Test summarize_clusters_map_reduce with document having few chunks""" + document_samples = { + 'doc1': { + 'chunks': [ + {'content': 'chunk 1'}, + {'content': 'chunk 2'} + ], + 'filename': 'test.pdf' + } + } + clusters = {0: ['doc1']} + + with patch('backend.utils.document_vector_utils.summarize_document') as mock_sum_doc: + mock_sum_doc.return_value = "Doc summary" + + with patch('backend.utils.document_vector_utils.summarize_cluster') as mock_sum_cluster: + mock_sum_cluster.return_value = "Cluster summary" + + result = summarize_clusters_map_reduce(document_samples, clusters) + assert isinstance(result, dict) + assert 0 in result + + def test_summarize_clusters_map_reduce_long_content(self): + """Test summarize_clusters_map_reduce with long content""" + long_content = 'x' * 1500 # Longer than 1000 chars + document_samples = { + 'doc1': { + 'chunks': [ + {'content': long_content} + ], + 'filename': 'test.pdf' + } + } + clusters = {0: ['doc1']} + + with patch('backend.utils.document_vector_utils.summarize_document') as mock_sum_doc: + mock_sum_doc.return_value = "Doc summary" + + with patch('backend.utils.document_vector_utils.summarize_cluster') as mock_sum_cluster: + mock_sum_cluster.return_value = "Cluster summary" + + result = summarize_clusters_map_reduce(document_samples, clusters) + assert isinstance(result, dict) + assert 0 in result + + def test_summarize_clusters_map_reduce_no_valid_docs(self): + """Test summarize_clusters_map_reduce with no valid document summaries""" + document_samples = { + 'doc1': { + 'chunks': [{'content': 'test content'}], + 'filename': 'test.pdf' + } + } + clusters = {0: ['doc1']} + + with patch('backend.utils.document_vector_utils.summarize_document') as mock_sum_doc: + mock_sum_doc.return_value = "" # Empty summary + + with patch('backend.utils.document_vector_utils.summarize_cluster') as mock_sum_cluster: + mock_sum_cluster.return_value = "Cluster summary" + + result = summarize_clusters_map_reduce(document_samples, clusters) + assert isinstance(result, dict) + assert 0 in result + + def test_summarize_cluster_legacy_exception(self): + """Test summarize_cluster_legacy with exception""" + cluster_content = "Test cluster content" + + # Mock file operations to raise exception + with patch('builtins.open', 
side_effect=Exception("File error")): + result = summarize_cluster_legacy(cluster_content) + assert "Failed to generate summary" in result + diff --git a/test/backend/test_llm_integration.py b/test/backend/test_llm_integration.py new file mode 100644 index 00000000..dfd62539 --- /dev/null +++ b/test/backend/test_llm_integration.py @@ -0,0 +1,98 @@ +""" +Test LLM integration for knowledge base summarization +""" + +import pytest +import sys +import os + +# Add backend to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend')) + +from utils.document_vector_utils import summarize_document, summarize_cluster + + +class TestLLMIntegration: + """Test LLM integration functionality""" + + def test_summarize_document_without_llm(self): + """Test document summarization without LLM (fallback mode)""" + content = "This is a test document with some content about machine learning and AI." + filename = "test_doc.txt" + + result = summarize_document(content, filename, language="zh", max_words=50) + + # Should return placeholder when no model_id/tenant_id provided + assert "[Document Summary: test_doc.txt]" in result + assert "max 50 words" in result + assert "Content:" in result + + def test_summarize_document_with_llm_params_no_config(self): + """Test document summarization with LLM parameters but no model config""" + content = "This is a test document with some content about machine learning and AI." + filename = "test_doc.txt" + + # Test with model_id and tenant_id but no actual LLM call (will fail due to missing config) + result = summarize_document( + content, filename, language="zh", max_words=50, + model_id=1, tenant_id="test_tenant" + ) + + # Should return error message when model config not found + assert "Failed to generate summary" in result or "No model configuration found" in result + + def test_summarize_cluster_without_llm(self): + """Test cluster summarization without LLM (fallback mode)""" + document_summaries = [ + "Document 1 is about machine learning algorithms.", + "Document 2 discusses neural networks and deep learning.", + "Document 3 covers AI applications in healthcare." + ] + + result = summarize_cluster(document_summaries, language="zh", max_words=100) + + # Should return placeholder when no model_id/tenant_id provided + assert "[Cluster Summary]" in result + assert "max 100 words" in result + assert "Based on 3 documents" in result + + def test_summarize_cluster_with_llm_params_no_config(self): + """Test cluster summarization with LLM parameters but no model config""" + document_summaries = [ + "Document 1 is about machine learning algorithms.", + "Document 2 discusses neural networks and deep learning." + ] + + result = summarize_cluster( + document_summaries, language="zh", max_words=100, + model_id=1, tenant_id="test_tenant" + ) + + # Should return error message when model config not found + assert "Failed to generate summary" in result or "No model configuration found" in result + + def test_summarize_document_english(self): + """Test document summarization in English""" + content = "This is a test document with some content about machine learning and AI." 
+ filename = "test_doc.txt" + + result = summarize_document(content, filename, language="en", max_words=50) + + # Should return placeholder when no model_id/tenant_id provided + assert "[Document Summary: test_doc.txt]" in result + assert "max 50 words" in result + assert "Content:" in result + + def test_summarize_cluster_english(self): + """Test cluster summarization in English""" + document_summaries = [ + "Document 1 is about machine learning algorithms.", + "Document 2 discusses neural networks and deep learning." + ] + + result = summarize_cluster(document_summaries, language="en", max_words=100) + + # Should return placeholder when no model_id/tenant_id provided + assert "[Cluster Summary]" in result + assert "max 100 words" in result + assert "Based on 2 documents" in result diff --git a/test/backend/test_summary_formatting.py b/test/backend/test_summary_formatting.py new file mode 100644 index 00000000..31b656e5 --- /dev/null +++ b/test/backend/test_summary_formatting.py @@ -0,0 +1,77 @@ +""" +Test summary formatting and display +""" + +import pytest +import sys +import os + +# Add backend to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'backend')) + +from utils.document_vector_utils import merge_cluster_summaries + + +class TestSummaryFormatting: + """Test summary formatting functionality""" + + def test_merge_cluster_summaries_with_html_separators(self): + """Test that cluster summaries are properly wrapped in HTML paragraph tags""" + cluster_summaries = { + 0: "这是第一个簇的总结,包含关于机器学习和人工智能的内容。", + 1: "这是第二个簇的总结,包含关于深度学习和神经网络的内容。", + 2: "这是第三个簇的总结,包含关于自然语言处理的内容。" + } + + result = merge_cluster_summaries(cluster_summaries) + + # Should contain HTML paragraph tags + assert "
<p>
    " in result + assert "
</p>
    " in result + assert result.count("
<p>
    ") == 3 # Should have 3 paragraph tags for 3 clusters + + # Should contain all cluster summaries + assert "第一个簇的总结" in result + assert "第二个簇的总结" in result + assert "第三个簇的总结" in result + + # Should be properly formatted with paragraph tags + assert "
<p>
    这是第一个簇的总结" in result + assert "
<p>
    这是第二个簇的总结" in result + assert "
<p>
    这是第三个簇的总结" in result + + def test_merge_cluster_summaries_single_cluster(self): + """Test merging with single cluster (wrapped in paragraph tag)""" + cluster_summaries = { + 0: "这是唯一的簇总结。" + } + + result = merge_cluster_summaries(cluster_summaries) + + # Should be wrapped in paragraph tag + assert "
<p>
    " in result + assert "
</p>
    " in result + assert result == "
<p>
    这是唯一的簇总结。
</p>
    " + + def test_merge_cluster_summaries_empty(self): + """Test merging with empty input""" + result = merge_cluster_summaries({}) + assert result == "" + + def test_merge_cluster_summaries_order(self): + """Test that clusters are merged in correct order""" + cluster_summaries = { + 2: "第三个簇", + 0: "第一个簇", + 1: "第二个簇" + } + + result = merge_cluster_summaries(cluster_summaries) + + # Should be in cluster ID order + lines = result.split('\n') + content_lines = [line for line in lines if line.strip() and '
<p>
    ' in line] + + assert "第一个簇" in content_lines[0] + assert "第二个簇" in content_lines[1] + assert "第三个簇" in content_lines[2] diff --git a/test/sdk/vector_database/test_elasticsearch_core.py b/test/sdk/vector_database/test_elasticsearch_core.py index eedeada9..0d495c4f 100644 --- a/test/sdk/vector_database/test_elasticsearch_core.py +++ b/test/sdk/vector_database/test_elasticsearch_core.py @@ -287,3 +287,562 @@ def test_preprocess_documents_with_zero_values(elasticsearch_core_instance): assert doc["create_time"] == "2025-01-15T10:30:00" assert doc["date"] == "2025-01-15" assert doc["process_source"] == "CustomProcessor" + + +def test_preprocess_large_batch_of_documents(elasticsearch_core_instance): + """Test preprocessing a large batch of documents (100+ chunks scenario).""" + # Simulate processing a large file that generates 150 chunks + large_docs = [ + { + "content": f"Chunk content number {i}", + "title": f"Document chunk {i}", + "filename": "large_document.pdf", + "path_or_url": "/path/to/large_document.pdf" + } + for i in range(150) + ] + content_field = "content" + + with patch('time.strftime') as mock_strftime, \ + patch('time.time') as mock_time, \ + patch('time.gmtime') as mock_gmtime: + + mock_strftime.side_effect = lambda fmt, t: "2025-01-15T10:30:00" if "T" in fmt else "2025-01-15" + mock_time.return_value = 1642234567 + mock_gmtime.return_value = None + + result = elasticsearch_core_instance._preprocess_documents( + large_docs, content_field) + + # Should process all 150 documents + assert len(result) == 150 + + # Verify each document has required fields + for i, doc in enumerate(result): + assert doc["content"] == f"Chunk content number {i}" + assert doc["title"] == f"Document chunk {i}" + assert doc["filename"] == "large_document.pdf" + assert doc["path_or_url"] == "/path/to/large_document.pdf" + assert "create_time" in doc + assert "date" in doc + assert "file_size" in doc + assert "process_source" in doc + assert "id" in doc + + +def test_preprocess_documents_performance_with_large_batch(elasticsearch_core_instance): + """Test that preprocessing performance is acceptable for large batches.""" + import time as time_module + + # Create 200 documents to test performance + large_docs = [ + { + "content": f"Content {i}" * 100, # Longer content + "title": f"Title {i}", + "filename": f"file_{i}.txt" + } + for i in range(200) + ] + content_field = "content" + + with patch('time.strftime') as mock_strftime, \ + patch('time.time') as mock_time, \ + patch('time.gmtime') as mock_gmtime: + + mock_strftime.side_effect = lambda fmt, t: "2025-01-15T10:30:00" if "T" in fmt else "2025-01-15" + mock_time.return_value = 1642234567 + mock_gmtime.return_value = None + + start = time_module.time() + result = elasticsearch_core_instance._preprocess_documents( + large_docs, content_field) + elapsed = time_module.time() - start + + # Should complete in reasonable time (< 5 seconds for 200 docs) + assert elapsed < 5.0 + + # All documents should be processed + assert len(result) == 200 + + +def test_preprocess_documents_maintains_order(elasticsearch_core_instance): + """Test that document order is preserved during preprocessing.""" + docs = [ + {"content": f"Content {i}", "sequence": i} + for i in range(50) + ] + content_field = "content" + + with patch('time.strftime') as mock_strftime, \ + patch('time.time') as mock_time, \ + patch('time.gmtime') as mock_gmtime: + + mock_strftime.side_effect = lambda fmt, t: "2025-01-15T10:30:00" if "T" in fmt else "2025-01-15" + mock_time.return_value = 1642234567 
+ mock_gmtime.return_value = None + + result = elasticsearch_core_instance._preprocess_documents( + docs, content_field) + + # Verify order is maintained + for i, doc in enumerate(result): + assert doc["sequence"] == i + assert doc["content"] == f"Content {i}" + + +# ---------------------------------------------------------------------------- +# Tests for index management methods +# ---------------------------------------------------------------------------- + +def test_create_vector_index_success(elasticsearch_core_instance): + """Test creating a new vector index successfully.""" + with patch.object(elasticsearch_core_instance.client.indices, 'exists') as mock_exists, \ + patch.object(elasticsearch_core_instance.client.indices, 'create') as mock_create, \ + patch.object(elasticsearch_core_instance, '_force_refresh_with_retry') as mock_refresh, \ + patch.object(elasticsearch_core_instance, '_ensure_index_ready') as mock_ready: + + mock_exists.return_value = False + mock_create.return_value = {"acknowledged": True} + mock_refresh.return_value = True + mock_ready.return_value = True + + result = elasticsearch_core_instance.create_vector_index( + "test_index", embedding_dim=1024) + + assert result is True + mock_exists.assert_called_once_with(index="test_index") + mock_create.assert_called_once() + mock_refresh.assert_called_once_with("test_index") + mock_ready.assert_called_once_with("test_index") + + +def test_create_vector_index_already_exists(elasticsearch_core_instance): + """Test creating an index that already exists.""" + with patch.object(elasticsearch_core_instance.client.indices, 'exists') as mock_exists, \ + patch.object(elasticsearch_core_instance, '_ensure_index_ready') as mock_ready: + + mock_exists.return_value = True + mock_ready.return_value = True + + result = elasticsearch_core_instance.create_vector_index( + "existing_index") + + assert result is True + mock_exists.assert_called_once_with(index="existing_index") + mock_ready.assert_called_once_with("existing_index") + + +def test_delete_index_success(elasticsearch_core_instance): + """Test deleting an index successfully.""" + with patch.object(elasticsearch_core_instance.client.indices, 'delete') as mock_delete: + mock_delete.return_value = {"acknowledged": True} + + result = elasticsearch_core_instance.delete_index("test_index") + + assert result is True + mock_delete.assert_called_once_with(index="test_index") + + +def test_delete_index_not_found(elasticsearch_core_instance): + """Test deleting an index that doesn't exist.""" + from elasticsearch import exceptions + + with patch.object(elasticsearch_core_instance.client.indices, 'delete') as mock_delete: + mock_delete.side_effect = exceptions.NotFoundError( + "Index not found", {}, {}) + + result = elasticsearch_core_instance.delete_index("nonexistent_index") + + assert result is False + mock_delete.assert_called_once_with(index="nonexistent_index") + + +def test_get_user_indices_success(elasticsearch_core_instance): + """Test getting user indices successfully.""" + with patch.object(elasticsearch_core_instance.client.indices, 'get_alias') as mock_get_alias: + mock_get_alias.return_value = { + "user_index_1": {}, + "user_index_2": {}, + ".system_index": {} + } + + result = elasticsearch_core_instance.get_user_indices() + + assert len(result) == 2 + assert "user_index_1" in result + assert "user_index_2" in result + assert ".system_index" not in result + + +# ---------------------------------------------------------------------------- +# Tests for document operations +# 
---------------------------------------------------------------------------- + +def test_index_documents_empty_list(elasticsearch_core_instance): + """Test indexing an empty list of documents.""" + mock_embedding_model = MagicMock() + + result = elasticsearch_core_instance.index_documents( + "test_index", + mock_embedding_model, + [], + content_field="content" + ) + + assert result == 0 + + +def test_index_documents_small_batch(elasticsearch_core_instance): + """Test indexing a small batch of documents (< 64).""" + mock_embedding_model = MagicMock() + mock_embedding_model.get_embeddings.return_value = [[0.1] * 1024] * 3 + mock_embedding_model.embedding_model_name = "test-model" + + documents = [ + {"content": "Test content 1", "title": "Test 1"}, + {"content": "Test content 2", "title": "Test 2"}, + {"content": "Test content 3", "title": "Test 3"} + ] + + with patch.object(elasticsearch_core_instance.client, 'bulk') as mock_bulk, \ + patch('time.strftime') as mock_strftime, \ + patch('time.time') as mock_time: + + mock_strftime.side_effect = lambda fmt, t: "2025-01-15T10:30:00" if "T" in fmt else "2025-01-15" + mock_time.return_value = 1642234567 + mock_bulk.return_value = {"errors": False, "items": []} + + result = elasticsearch_core_instance.index_documents( + "test_index", + mock_embedding_model, + documents, + content_field="content" + ) + + assert result == 3 + mock_embedding_model.get_embeddings.assert_called_once() + mock_bulk.assert_called_once() + + +def test_index_documents_large_batch(elasticsearch_core_instance): + """Test indexing a large batch of documents (>= 64).""" + mock_embedding_model = MagicMock() + mock_embedding_model.get_embeddings.return_value = [[0.1] * 1024] * 64 + mock_embedding_model.embedding_model_name = "test-model" + + documents = [ + {"content": f"Test content {i}", "title": f"Test {i}"} + for i in range(100) + ] + + with patch.object(elasticsearch_core_instance.client, 'bulk') as mock_bulk, \ + patch.object(elasticsearch_core_instance, '_force_refresh_with_retry') as mock_refresh, \ + patch('time.strftime') as mock_strftime, \ + patch('time.time') as mock_time, \ + patch('time.sleep'): + + mock_strftime.side_effect = lambda fmt, t: "2025-01-15T10:30:00" if "T" in fmt else "2025-01-15" + mock_time.return_value = 1642234567 + mock_bulk.return_value = {"errors": False, "items": []} + mock_refresh.return_value = True + + result = elasticsearch_core_instance.index_documents( + "test_index", + mock_embedding_model, + documents, + batch_size=64, + content_field="content" + ) + + assert result == 100 + assert mock_embedding_model.get_embeddings.call_count >= 2 + mock_bulk.assert_called() + mock_refresh.assert_called_once_with("test_index") + + +def test_delete_documents_by_path_or_url_success(elasticsearch_core_instance): + """Test deleting documents by path_or_url successfully.""" + with patch.object(elasticsearch_core_instance.client, 'delete_by_query') as mock_delete: + mock_delete.return_value = {"deleted": 5} + + result = elasticsearch_core_instance.delete_documents_by_path_or_url( + "test_index", + "/path/to/file.pdf" + ) + + assert result == 5 + mock_delete.assert_called_once() + + +# ---------------------------------------------------------------------------- +# Tests for search operations +# ---------------------------------------------------------------------------- + +def test_accurate_search_success(elasticsearch_core_instance): + """Test accurate search with text matching.""" + with patch.object(elasticsearch_core_instance, 'exec_query') as 
mock_exec, \ + patch('sdk.nexent.vector_database.elasticsearch_core.calculate_term_weights') as mock_weights, \ + patch('sdk.nexent.vector_database.elasticsearch_core.build_weighted_query') as mock_build: + + mock_weights.return_value = {"test": 1.0} + mock_build.return_value = { + "query": {"match": {"content": "test query"}}} + mock_exec.return_value = [ + { + "score": 10.5, + "document": {"content": "Test document", "title": "Test"}, + "index": "test_index" + } + ] + + result = elasticsearch_core_instance.accurate_search( + ["test_index"], + "test query", + top_k=5 + ) + + assert len(result) == 1 + assert result[0]["score"] == 10.5 + mock_weights.assert_called_once_with("test query") + mock_build.assert_called_once_with("test query", {"test": 1.0}) + mock_exec.assert_called_once() + + +def test_semantic_search_success(elasticsearch_core_instance): + """Test semantic search with vector similarity.""" + mock_embedding_model = MagicMock() + mock_embedding_model.get_embeddings.return_value = [[0.1] * 1024] + + with patch.object(elasticsearch_core_instance, 'exec_query') as mock_exec: + mock_exec.return_value = [ + { + "score": 0.95, + "document": {"content": "Similar document", "title": "Doc"}, + "index": "test_index" + } + ] + + result = elasticsearch_core_instance.semantic_search( + ["test_index"], + "test query", + mock_embedding_model, + top_k=5 + ) + + assert len(result) == 1 + assert result[0]["score"] == 0.95 + mock_embedding_model.get_embeddings.assert_called_once_with( + "test query") + mock_exec.assert_called_once() + + +def test_hybrid_search_success(elasticsearch_core_instance): + """Test hybrid search combining accurate and semantic results.""" + mock_embedding_model = MagicMock() + + with patch.object(elasticsearch_core_instance, 'accurate_search') as mock_accurate, \ + patch.object(elasticsearch_core_instance, 'semantic_search') as mock_semantic: + + mock_accurate.return_value = [ + { + "score": 10.0, + "document": {"id": "doc1", "content": "Test doc 1"}, + "index": "test_index" + } + ] + + mock_semantic.return_value = [ + { + "score": 0.9, + "document": {"id": "doc1", "content": "Test doc 1"}, + "index": "test_index" + }, + { + "score": 0.8, + "document": {"id": "doc2", "content": "Test doc 2"}, + "index": "test_index" + } + ] + + result = elasticsearch_core_instance.hybrid_search( + ["test_index"], + "test query", + mock_embedding_model, + top_k=5, + weight_accurate=0.3 + ) + + assert len(result) == 2 + assert all("score" in r for r in result) + assert all("document" in r for r in result) + mock_accurate.assert_called_once() + mock_semantic.assert_called_once() + + +# ---------------------------------------------------------------------------- +# Tests for statistics and monitoring +# ---------------------------------------------------------------------------- + +def test_get_file_list_with_details_success(elasticsearch_core_instance): + """Test getting file list with details.""" + with patch.object(elasticsearch_core_instance.client, 'search') as mock_search: + mock_search.return_value = { + "aggregations": { + "unique_sources": { + "buckets": [ + { + "file_sample": { + "hits": { + "hits": [ + { + "_source": { + "path_or_url": "/path/to/file1.pdf", + "filename": "file1.pdf", + "file_size": 1024, + "create_time": "2025-01-15T10:30:00" + } + } + ] + } + } + } + ] + } + } + } + + result = elasticsearch_core_instance.get_file_list_with_details( + "test_index") + + assert len(result) == 1 + assert result[0]["path_or_url"] == "/path/to/file1.pdf" + assert result[0]["filename"] 
== "file1.pdf" + assert result[0]["file_size"] == 1024 + mock_search.assert_called_once() + + +def test_get_index_mapping_success(elasticsearch_core_instance): + """Test getting index mapping.""" + with patch.object(elasticsearch_core_instance.client.indices, 'get_mapping') as mock_get_mapping: + mock_get_mapping.return_value = { + "test_index": { + "mappings": { + "properties": { + "content": {"type": "text"}, + "embedding": {"type": "dense_vector"} + } + } + } + } + + result = elasticsearch_core_instance.get_index_mapping(["test_index"]) + + assert "test_index" in result + assert "content" in result["test_index"] + assert "embedding" in result["test_index"] + mock_get_mapping.assert_called_once() + + +def test_get_index_stats_success(elasticsearch_core_instance): + """Test getting index statistics.""" + with patch.object(elasticsearch_core_instance.client.indices, 'stats') as mock_stats, \ + patch.object(elasticsearch_core_instance.client.indices, 'get_settings') as mock_settings, \ + patch.object(elasticsearch_core_instance.client, 'search') as mock_search: + + mock_stats.return_value = { + "indices": { + "test_index": { + "primaries": { + "docs": {"count": 100}, + "store": {"size_in_bytes": 1024000}, + "search": {"query_total": 50}, + "request_cache": {"hit_count": 25} + } + } + } + } + + mock_settings.return_value = { + "test_index": { + "settings": { + "index": { + "creation_date": "1642234567000" + } + } + } + } + + mock_search.return_value = { + "aggregations": { + "unique_path_or_url_count": {"value": 10}, + "process_sources": {"buckets": [{"key": "Unstructured"}]}, + "embedding_models": {"buckets": [{"key": "test-model"}]} + } + } + + result = elasticsearch_core_instance.get_index_stats( + ["test_index"], embedding_dim=1024) + + assert "test_index" in result + assert result["test_index"]["base_info"]["doc_count"] == 10 + assert result["test_index"]["base_info"]["chunk_count"] == 100 + mock_stats.assert_called_once() + mock_settings.assert_called_once() + mock_search.assert_called_once() + + +# ---------------------------------------------------------------------------- +# Tests for error handling +# ---------------------------------------------------------------------------- + +def test_handle_bulk_errors_with_errors(elasticsearch_core_instance): + """Test handling bulk operation errors.""" + response = { + "errors": True, + "items": [ + { + "index": { + "error": { + "type": "mapper_parsing_exception", + "reason": "Failed to parse mapping" + } + } + } + ] + } + + # Should not raise exception, just log errors + elasticsearch_core_instance._handle_bulk_errors(response) + + +def test_handle_bulk_errors_version_conflict(elasticsearch_core_instance): + """Test handling version conflict errors (should be ignored).""" + response = { + "errors": True, + "items": [ + { + "index": { + "error": { + "type": "version_conflict_engine_exception", + "reason": "Version conflict" + } + } + } + ] + } + + # Should not raise exception or log error for version conflicts + elasticsearch_core_instance._handle_bulk_errors(response) + + +def test_bulk_operation_context(elasticsearch_core_instance): + """Test bulk operation context manager.""" + with patch.object(elasticsearch_core_instance, '_apply_bulk_settings') as mock_apply, \ + patch.object(elasticsearch_core_instance, '_restore_normal_settings') as mock_restore: + + with elasticsearch_core_instance.bulk_operation_context("test_index", estimated_duration=60) as operation_id: + assert operation_id is not None + assert "bulk_" in operation_id + + 
mock_apply.assert_called_once_with("test_index") + mock_restore.assert_called_once_with("test_index") diff --git a/test/sdk/vector_database/test_elasticsearch_core_coverage.py b/test/sdk/vector_database/test_elasticsearch_core_coverage.py new file mode 100644 index 00000000..5bc28bc4 --- /dev/null +++ b/test/sdk/vector_database/test_elasticsearch_core_coverage.py @@ -0,0 +1,495 @@ +""" +Supplementary test module for elasticsearch_core to improve code coverage + +Tests for functions not fully covered in the main test file. +""" +import pytest +from unittest.mock import MagicMock, patch, mock_open +import time +import os +import sys +from typing import List, Dict, Any + +# Add the project root to the path +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.abspath(os.path.join(current_dir, "../../..")) +sys.path.insert(0, project_root) + +# Import the class under test +from sdk.nexent.vector_database.elasticsearch_core import ElasticSearchCore +from elasticsearch import exceptions + + +class TestElasticSearchCoreCoverage: + """Test class for improving elasticsearch_core coverage""" + + @pytest.fixture + def es_core(self): + """Create an ElasticSearchCore instance for testing.""" + return ElasticSearchCore( + host="http://localhost:9200", + api_key="test_api_key", + verify_certs=False, + ssl_show_warn=False + ) + + def test_force_refresh_with_retry_success(self, es_core): + """Test _force_refresh_with_retry successful refresh""" + es_core.client = MagicMock() + es_core.client.indices.refresh.return_value = {"_shards": {"total": 1, "successful": 1}} + + result = es_core._force_refresh_with_retry("test_index") + assert result is True + es_core.client.indices.refresh.assert_called_once_with(index="test_index") + + def test_force_refresh_with_retry_failure_retry(self, es_core): + """Test _force_refresh_with_retry with retries""" + es_core.client = MagicMock() + es_core.client.indices.refresh.side_effect = [ + Exception("Connection error"), + Exception("Still failing"), + {"_shards": {"total": 1, "successful": 1}} + ] + + with patch('time.sleep'): # Mock sleep to speed up test + result = es_core._force_refresh_with_retry("test_index", max_retries=3) + assert result is True + assert es_core.client.indices.refresh.call_count == 3 + + def test_force_refresh_with_retry_max_retries_exceeded(self, es_core): + """Test _force_refresh_with_retry when max retries exceeded""" + es_core.client = MagicMock() + es_core.client.indices.refresh.side_effect = Exception("Persistent error") + + with patch('time.sleep'): # Mock sleep to speed up test + result = es_core._force_refresh_with_retry("test_index", max_retries=2) + assert result is False + assert es_core.client.indices.refresh.call_count == 2 + + def test_ensure_index_ready_success(self, es_core): + """Test _ensure_index_ready successful case""" + es_core.client = MagicMock() + es_core.client.cluster.health.return_value = {"status": "green"} + es_core.client.search.return_value = {"hits": {"total": {"value": 0}}} + + result = es_core._ensure_index_ready("test_index") + assert result is True + + def test_ensure_index_ready_yellow_status(self, es_core): + """Test _ensure_index_ready with yellow status""" + es_core.client = MagicMock() + es_core.client.cluster.health.return_value = {"status": "yellow"} + es_core.client.search.return_value = {"hits": {"total": {"value": 0}}} + + result = es_core._ensure_index_ready("test_index") + assert result is True + + def test_ensure_index_ready_timeout(self, es_core): + """Test 
_ensure_index_ready timeout scenario""" + es_core.client = MagicMock() + es_core.client.cluster.health.return_value = {"status": "red"} + + with patch('time.sleep'): # Mock sleep to speed up test + result = es_core._ensure_index_ready("test_index", timeout=1) + assert result is False + + def test_ensure_index_ready_exception(self, es_core): + """Test _ensure_index_ready with exception""" + es_core.client = MagicMock() + es_core.client.cluster.health.side_effect = Exception("Connection error") + + with patch('time.sleep'): # Mock sleep to speed up test + result = es_core._ensure_index_ready("test_index", timeout=1) + assert result is False + + def test_apply_bulk_settings_success(self, es_core): + """Test _apply_bulk_settings successful case""" + es_core.client = MagicMock() + es_core.client.indices.put_settings.return_value = {"acknowledged": True} + + es_core._apply_bulk_settings("test_index") + es_core.client.indices.put_settings.assert_called_once() + + def test_apply_bulk_settings_failure(self, es_core): + """Test _apply_bulk_settings with exception""" + es_core.client = MagicMock() + es_core.client.indices.put_settings.side_effect = Exception("Settings error") + + # Should not raise exception, just log warning + es_core._apply_bulk_settings("test_index") + es_core.client.indices.put_settings.assert_called_once() + + def test_restore_normal_settings_success(self, es_core): + """Test _restore_normal_settings successful case""" + es_core.client = MagicMock() + es_core.client.indices.put_settings.return_value = {"acknowledged": True} + es_core._force_refresh_with_retry = MagicMock(return_value=True) + + es_core._restore_normal_settings("test_index") + es_core.client.indices.put_settings.assert_called_once() + es_core._force_refresh_with_retry.assert_called_once_with("test_index") + + def test_restore_normal_settings_failure(self, es_core): + """Test _restore_normal_settings with exception""" + es_core.client = MagicMock() + es_core.client.indices.put_settings.side_effect = Exception("Settings error") + + # Should not raise exception, just log warning + es_core._restore_normal_settings("test_index") + es_core.client.indices.put_settings.assert_called_once() + + def test_delete_index_success(self, es_core): + """Test delete_index successful case""" + es_core.client = MagicMock() + es_core.client.indices.delete.return_value = {"acknowledged": True} + + result = es_core.delete_index("test_index") + assert result is True + es_core.client.indices.delete.assert_called_once_with(index="test_index") + + def test_delete_index_not_found(self, es_core): + """Test delete_index when index not found""" + es_core.client = MagicMock() + # Create a proper NotFoundError with required parameters + not_found_error = exceptions.NotFoundError(404, "Index not found", {"error": {"type": "index_not_found_exception"}}) + es_core.client.indices.delete.side_effect = not_found_error + + result = es_core.delete_index("test_index") + assert result is False + es_core.client.indices.delete.assert_called_once_with(index="test_index") + + def test_delete_index_general_exception(self, es_core): + """Test delete_index with general exception""" + es_core.client = MagicMock() + es_core.client.indices.delete.side_effect = Exception("General error") + + result = es_core.delete_index("test_index") + assert result is False + es_core.client.indices.delete.assert_called_once_with(index="test_index") + + def test_handle_bulk_errors_no_errors(self, es_core): + """Test _handle_bulk_errors when no errors in response""" + response = 
{"errors": False, "items": []} + es_core._handle_bulk_errors(response) + # Should not raise any exceptions + + def test_handle_bulk_errors_with_version_conflict(self, es_core): + """Test _handle_bulk_errors with version conflict (should be ignored)""" + response = { + "errors": True, + "items": [ + { + "index": { + "error": { + "type": "version_conflict_engine_exception", + "reason": "Document already exists", + "caused_by": { + "type": "version_conflict", + "reason": "Document version conflict" + } + } + } + } + ] + } + es_core._handle_bulk_errors(response) + # Should not raise any exceptions for version conflicts + + def test_handle_bulk_errors_with_fatal_error(self, es_core): + """Test _handle_bulk_errors with fatal error""" + response = { + "errors": True, + "items": [ + { + "index": { + "error": { + "type": "mapper_parsing_exception", + "reason": "Failed to parse field", + "caused_by": { + "type": "json_parse_exception", + "reason": "Unexpected character" + } + } + } + } + ] + } + es_core._handle_bulk_errors(response) + # Should log error but not raise exception + + def test_handle_bulk_errors_with_caused_by(self, es_core): + """Test _handle_bulk_errors with caused_by information""" + response = { + "errors": True, + "items": [ + { + "index": { + "error": { + "type": "illegal_argument_exception", + "reason": "Invalid argument", + "caused_by": { + "type": "json_parse_exception", + "reason": "JSON parsing failed" + } + } + } + } + ] + } + es_core._handle_bulk_errors(response) + # Should log both main error and caused_by error + + def test_delete_documents_by_path_or_url_success(self, es_core): + """Test delete_documents_by_path_or_url successful case""" + es_core.client = MagicMock() + es_core.client.delete_by_query.return_value = {"deleted": 5} + + result = es_core.delete_documents_by_path_or_url("test_index", "/path/to/file.pdf") + assert result == 5 + es_core.client.delete_by_query.assert_called_once() + + def test_delete_documents_by_path_or_url_exception(self, es_core): + """Test delete_documents_by_path_or_url with exception""" + es_core.client = MagicMock() + es_core.client.delete_by_query.side_effect = Exception("Delete error") + + result = es_core.delete_documents_by_path_or_url("test_index", "/path/to/file.pdf") + assert result == 0 + es_core.client.delete_by_query.assert_called_once() + + def test_get_index_mapping_success(self, es_core): + """Test get_index_mapping successful case""" + es_core.client = MagicMock() + es_core.client.indices.get_mapping.return_value = { + "test_index": { + "mappings": { + "properties": { + "title": {"type": "text"}, + "content": {"type": "text"} + } + } + } + } + + result = es_core.get_index_mapping(["test_index"]) + assert "test_index" in result + assert "title" in result["test_index"] + assert "content" in result["test_index"] + + def test_get_index_mapping_exception(self, es_core): + """Test get_index_mapping with exception""" + es_core.client = MagicMock() + es_core.client.indices.get_mapping.side_effect = Exception("Mapping error") + + result = es_core.get_index_mapping(["test_index"]) + # The function returns empty list for failed indices, not empty dict + assert "test_index" in result + assert result["test_index"] == [] + + def test_get_index_stats_success(self, es_core): + """Test get_index_stats successful case""" + es_core.client = MagicMock() + es_core.client.indices.stats.return_value = { + "indices": { + "test_index": { + "primaries": { + "docs": {"count": 100}, + "store": {"size_in_bytes": 1024}, + "search": {"query_total": 50}, + 
"request_cache": {"hit_count": 25} + } + } + } + } + es_core.client.indices.get_settings.return_value = { + "test_index": { + "settings": { + "index": { + "number_of_shards": "1", + "number_of_replicas": "0", + "creation_date": "1640995200000" + } + } + } + } + es_core.client.search.return_value = { + "aggregations": { + "unique_path_or_url_count": {"value": 10}, + "process_sources": {"buckets": [{"key": "test_source"}]}, + "embedding_models": {"buckets": [{"key": "test_model"}]} + } + } + + result = es_core.get_index_stats(["test_index"]) + assert "test_index" in result + assert "base_info" in result["test_index"] + assert "search_performance" in result["test_index"] + + def test_get_index_stats_exception(self, es_core): + """Test get_index_stats with exception""" + es_core.client = MagicMock() + es_core.client.indices.stats.side_effect = Exception("Stats error") + + result = es_core.get_index_stats(["test_index"]) + # The function returns error info for failed indices, not empty dict + assert "test_index" in result + assert "error" in result["test_index"] + + def test_get_index_stats_with_embedding_dim(self, es_core): + """Test get_index_stats with embedding dimension""" + es_core.client = MagicMock() + es_core.client.indices.stats.return_value = { + "indices": { + "test_index": { + "primaries": { + "docs": {"count": 100}, + "store": {"size_in_bytes": 1024}, + "search": {"query_total": 50}, + "request_cache": {"hit_count": 25} + } + } + } + } + es_core.client.indices.get_settings.return_value = { + "test_index": { + "settings": { + "index": { + "number_of_shards": "1", + "number_of_replicas": "0", + "creation_date": "1640995200000" + } + } + } + } + es_core.client.search.return_value = { + "aggregations": { + "unique_path_or_url_count": {"value": 10}, + "process_sources": {"buckets": [{"key": "test_source"}]}, + "embedding_models": {"buckets": [{"key": "test_model"}]} + } + } + + result = es_core.get_index_stats(["test_index"], embedding_dim=512) + assert "test_index" in result + assert "base_info" in result["test_index"] + assert "search_performance" in result["test_index"] + assert result["test_index"]["base_info"]["embedding_dim"] == 512 + + def test_bulk_operation_context_success(self, es_core): + """Test bulk_operation_context successful case""" + es_core._bulk_operations = {} + es_core._operation_counter = 0 + es_core._settings_lock = MagicMock() + es_core._apply_bulk_settings = MagicMock() + es_core._restore_normal_settings = MagicMock() + + with es_core.bulk_operation_context("test_index") as operation_id: + assert operation_id is not None + assert "test_index" in es_core._bulk_operations + es_core._apply_bulk_settings.assert_called_once_with("test_index") + + # After context exit, should restore settings + es_core._restore_normal_settings.assert_called_once_with("test_index") + + def test_bulk_operation_context_multiple_operations(self, es_core): + """Test bulk_operation_context with multiple operations""" + es_core._bulk_operations = {} + es_core._operation_counter = 0 + es_core._settings_lock = MagicMock() + es_core._apply_bulk_settings = MagicMock() + es_core._restore_normal_settings = MagicMock() + + # First operation + with es_core.bulk_operation_context("test_index") as op1: + assert op1 is not None + es_core._apply_bulk_settings.assert_called_once() + + # After first operation exits, settings should be restored + es_core._restore_normal_settings.assert_called_once_with("test_index") + + # Second operation - will apply settings again since first operation is done + with 
es_core.bulk_operation_context("test_index") as op2: + assert op2 is not None + # Should call apply_bulk_settings again since first operation is done + assert es_core._apply_bulk_settings.call_count == 2 + + # After second operation exits, should restore settings again + assert es_core._restore_normal_settings.call_count == 2 + + def test_small_batch_insert_success(self, es_core): + """Test _small_batch_insert successful case""" + es_core.client = MagicMock() + es_core.client.bulk.return_value = {"items": [], "errors": False} + es_core._preprocess_documents = MagicMock(return_value=[ + {"content": "test content", "title": "test"} + ]) + es_core._handle_bulk_errors = MagicMock() + + mock_embedding_model = MagicMock() + mock_embedding_model.get_embeddings.return_value = [[0.1, 0.2, 0.3]] + mock_embedding_model.embedding_model_name = "test_model" + + documents = [{"content": "test content", "title": "test"}] + + result = es_core._small_batch_insert("test_index", documents, "content", mock_embedding_model) + assert result == 1 + es_core.client.bulk.assert_called_once() + + def test_small_batch_insert_exception(self, es_core): + """Test _small_batch_insert with exception""" + es_core._preprocess_documents = MagicMock(side_effect=Exception("Preprocess error")) + + mock_embedding_model = MagicMock() + documents = [{"content": "test content", "title": "test"}] + + result = es_core._small_batch_insert("test_index", documents, "content", mock_embedding_model) + assert result == 0 + + def test_large_batch_insert_success(self, es_core): + """Test _large_batch_insert successful case""" + es_core.client = MagicMock() + es_core.client.bulk.return_value = {"items": [], "errors": False} + es_core._preprocess_documents = MagicMock(return_value=[ + {"content": "test content", "title": "test"} + ]) + es_core._handle_bulk_errors = MagicMock() + + mock_embedding_model = MagicMock() + mock_embedding_model.get_embeddings.return_value = [[0.1, 0.2, 0.3]] + mock_embedding_model.embedding_model_name = "test_model" + + documents = [{"content": "test content", "title": "test"}] + + result = es_core._large_batch_insert("test_index", documents, 10, "content", mock_embedding_model) + assert result == 1 + es_core.client.bulk.assert_called_once() + + def test_large_batch_insert_embedding_error(self, es_core): + """Test _large_batch_insert with embedding API error""" + es_core.client = MagicMock() + es_core._preprocess_documents = MagicMock(return_value=[ + {"content": "test content", "title": "test"} + ]) + + mock_embedding_model = MagicMock() + mock_embedding_model.get_embeddings.side_effect = Exception("Embedding API error") + + documents = [{"content": "test content", "title": "test"}] + + result = es_core._large_batch_insert("test_index", documents, 10, "content", mock_embedding_model) + assert result == 0 # No documents indexed due to embedding error + + def test_large_batch_insert_no_embeddings(self, es_core): + """Test _large_batch_insert with no successful embeddings""" + es_core.client = MagicMock() + es_core._preprocess_documents = MagicMock(return_value=[ + {"content": "test content", "title": "test"} + ]) + + mock_embedding_model = MagicMock() + mock_embedding_model.get_embeddings.side_effect = Exception("Embedding API error") + + documents = [{"content": "test content", "title": "test"}] + + result = es_core._large_batch_insert("test_index", documents, 10, "content", mock_embedding_model) + assert result == 0 # No documents indexed
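
Note on the batch-insert tests above: because both the embedding model and the Elasticsearch client are mocked, the control flow being exercised (embed one slice at a time, bulk-index it, and count nothing when the embedding call fails) is easy to miss. The sketch below is a minimal, hypothetical illustration of that flow under the assumptions stated in its comments; it is not the actual `_small_batch_insert`/`_large_batch_insert` implementation, and the helper name `index_in_batches`, the per-batch error policy, and the `client.bulk(operations=...)` call shape are illustrative only.

from typing import Any, Dict, List


def index_in_batches(
    client: Any,
    index_name: str,
    documents: List[Dict[str, Any]],
    embedding_model: Any,
    batch_size: int = 64,
    content_field: str = "content",
) -> int:
    """Hypothetical batching loop (assumption, not the library code):
    embed each slice with one get_embeddings call, bulk-index the slice,
    and skip slices whose embedding call raises, so an embedding failure
    yields a return value of 0 for that batch."""
    total_indexed = 0
    for start in range(0, len(documents), batch_size):
        batch = documents[start:start + batch_size]
        try:
            # One embedding call per batch, matching the call-count assertions.
            embeddings = embedding_model.get_embeddings(
                [doc[content_field] for doc in batch])
        except Exception:
            # Embedding failure: drop this batch and move on to the next one.
            continue
        operations = []
        for doc, vector in zip(batch, embeddings):
            operations.append({"index": {"_index": index_name}})
            operations.append({**doc, "embedding": vector})
        # Schematic bulk call; the exact client API shape is an assumption here.
        client.bulk(operations=operations)
        total_indexed += len(batch)
    return total_indexed

Run with a MagicMock client and embedding model, this sketch reproduces the behaviour the final tests assert: a failing get_embeddings call leaves the batch unindexed (count 0), while a successful call counts every document in the batch.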