From d3f2afb7940e84f3d71537fc919cccd442f73efc Mon Sep 17 00:00:00 2001 From: Filip Komarzyniec Date: Wed, 24 Jun 2026 11:59:01 +0200 Subject: [PATCH 01/16] modified exposed (per rag pattern) payload of responses api requests to align with previous payload of chat/completion; relevant tests changes Signed-off-by: Filip Komarzyniec rh-pre-commit.version: 2.3.2 rh-pre-commit.check-secrets: ENABLED Signed-off-by: Lukasz Cmielowski Assisted-by: Cursor --- .../assets_generator/pattern_builder.py | 17 ++++++++++-- ai4rag/core/experiment/experiment.py | 2 ++ .../assets_generator/test_pattern_builder.py | 26 +++++++++++++++---- 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/ai4rag/components/assets_generator/pattern_builder.py b/ai4rag/components/assets_generator/pattern_builder.py index 3e66195..a686573 100644 --- a/ai4rag/components/assets_generator/pattern_builder.py +++ b/ai4rag/components/assets_generator/pattern_builder.py @@ -28,8 +28,16 @@ def build_pattern_json( "model": pattern["settings"]["generation"]["model_id"], "stream": False, "store": False, - "input": "", - "instructions": pattern["settings"]["generation"]["system_message_text"], + "input": [ + { + "content": [{"text": pattern["settings"]["generation"]["system_message_text"], "type": "input_text"}], + "role": "system", + }, + {"content": [{"text": "", "type": "input_text"}], "role": "user"}, + ], + "max_output_tokens": pattern["settings"]["generation"]["max_completion_tokens"], + "temperature": pattern["settings"]["generation"]["temperature"], + "tool_choice": {"mode": "required", "tools": [{}], "type": "file_search"}, "tools": [ { "type": "file_search", @@ -37,6 +45,7 @@ def build_pattern_json( "ranking_options": { "max_num_results": pattern["settings"]["retrieval"]["number_of_chunks"], }, + "max_num_results": pattern["settings"]["retrieval"]["number_of_chunks"], }, ], "include": ["file_search_call.results"], @@ -52,5 +61,9 @@ def build_pattern_json( 
pattern["settings"]["responses_template"]["tools"][0]["ranking_options"]["impact_factor"] = ranker_k elif search_mode == "hybrid" and ranker_strategy == "weighted" and ranker_alpha is not None and ranker_alpha != 1: pattern["settings"]["responses_template"]["tools"][0]["ranking_options"]["alpha"] = ranker_alpha + else: + pattern["settings"]["responses_template"]["tools"][0]["ranking_options"].update( + {"ranker": "auto", "weights": {"vector": 1.0, "neural": 0.0, "keyword": 0.0}} + ) return pattern diff --git a/ai4rag/core/experiment/experiment.py b/ai4rag/core/experiment/experiment.py index fefecc4..de92762 100644 --- a/ai4rag/core/experiment/experiment.py +++ b/ai4rag/core/experiment/experiment.py @@ -334,6 +334,8 @@ def run_single_evaluation(self, rag_params: RAGParamsType) -> float: "retrieval": retrieval_params, "generation": { "model_id": foundation_model.model_id, + "temperature": foundation_model.params.temperature, + "max_completion_tokens": foundation_model.params.max_completion_tokens, "context_template_text": context_template_text, "user_message_text": user_message_text, "system_message_text": system_message_text, diff --git a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py index 96f5e50..ea36859 100644 --- a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py +++ b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py @@ -41,6 +41,8 @@ def _make_pattern(**overrides) -> dict: }, "generation": { "model_id": "ibm/granite-3-8b-instruct", + "temperature": 0.7, + "max_completion_tokens": 1024, "system_message_text": "Answer based on context only.", "user_message_text": "Context: {reference_documents}\nQ: {question}", "context_template_text": "{document}", @@ -73,12 +75,21 @@ def test_adds_responses_template(self): assert rt["model"] == "ibm/granite-3-8b-instruct" assert rt["stream"] is False assert rt["store"] is False - assert rt["input"] == "" - assert rt["instructions"] == 
"Answer based on context only." + assert rt["input"] == [ + { + "content": [{"text": "Answer based on context only.", "type": "input_text"}], + "role": "system", + }, + {"content": [{"text": "", "type": "input_text"}], "role": "user"}, + ] + assert rt["max_output_tokens"] == 1024 + assert rt["temperature"] == 0.7 + assert rt["tool_choice"] == {"mode": "required", "tools": [{}], "type": "file_search"} assert len(rt["tools"]) == 1 assert rt["tools"][0]["type"] == "file_search" assert "test_collection_001" in rt["tools"][0]["vector_store_ids"] assert rt["tools"][0]["ranking_options"]["max_num_results"] == 5 + assert rt["tools"][0]["max_num_results"] == 5 assert rt["include"] == ["file_search_call.results"] def test_returns_same_dict(self): @@ -126,13 +137,18 @@ def test_hybrid_weighted_ranking_options(self): assert ro["alpha"] == 0.7 assert ro["max_num_results"] == 5 - def test_simple_retrieval_has_only_max_num_results(self): - """Simple retrieval must have ranking_options with only max_num_results.""" + def test_simple_retrieval_default_ranking_options(self): + """Simple retrieval must have default ranker and weights in ranking_options.""" pattern = _make_pattern() build_pattern_json(pattern) ro = pattern["settings"]["responses_template"]["tools"][0]["ranking_options"] - assert ro == {"max_num_results": 5} + assert ro == { + "max_num_results": 5, + "ranker": "auto", + "weights": {"vector": 1.0, "neural": 0.0, "keyword": 0.0}, + } + assert pattern["settings"]["responses_template"]["tools"][0]["max_num_results"] == 5 def test_preserves_existing_pattern_fields(self): """Existing pattern fields (name, chunking, embedding, etc.) 
must not be altered.""" From 4cecb1fab86294104f42626f19d518967a599d49 Mon Sep 17 00:00:00 2001 From: Filip Komarzyniec Date: Wed, 24 Jun 2026 13:18:20 +0200 Subject: [PATCH 02/16] changed responses_template object according to PR review (explicit ranker, deleted neural weights and unnecessary fields) Signed-off-by: Filip Komarzyniec Signed-off-by: Filip Komarzyniec Signed-off-by: Filip Komarzyniec rh-pre-commit.version: 2.3.2 rh-pre-commit.check-secrets: ENABLED Signed-off-by: Lukasz Cmielowski Assisted-by: Cursor --- .../assets_generator/pattern_builder.py | 19 +++++++++++-------- .../assets_generator/test_pattern_builder.py | 19 ++++++------------- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/ai4rag/components/assets_generator/pattern_builder.py b/ai4rag/components/assets_generator/pattern_builder.py index a686573..17542f9 100644 --- a/ai4rag/components/assets_generator/pattern_builder.py +++ b/ai4rag/components/assets_generator/pattern_builder.py @@ -42,9 +42,6 @@ def build_pattern_json( { "type": "file_search", "vector_store_ids": [pattern["settings"]["vector_store_binding"]["vector_store_id"]], - "ranking_options": { - "max_num_results": pattern["settings"]["retrieval"]["number_of_chunks"], - }, "max_num_results": pattern["settings"]["retrieval"]["number_of_chunks"], }, ], @@ -58,12 +55,18 @@ def build_pattern_json( ranker_alpha = retrieval_settings.get("ranker_alpha") if search_mode == "hybrid" and ranker_strategy == "rrf" and ranker_k is not None and ranker_k > 0: - pattern["settings"]["responses_template"]["tools"][0]["ranking_options"]["impact_factor"] = ranker_k + pattern["settings"]["responses_template"]["tools"][0]["ranking_options"] = { + "ranker": "rrf", + "impact_factor": ranker_k, + } elif search_mode == "hybrid" and ranker_strategy == "weighted" and ranker_alpha is not None and ranker_alpha != 1: - pattern["settings"]["responses_template"]["tools"][0]["ranking_options"]["alpha"] = ranker_alpha + 
pattern["settings"]["responses_template"]["tools"][0]["ranking_options"] = { + "ranker": "weighted", + "alpha": ranker_alpha, + } else: - pattern["settings"]["responses_template"]["tools"][0]["ranking_options"].update( - {"ranker": "auto", "weights": {"vector": 1.0, "neural": 0.0, "keyword": 0.0}} - ) + pattern["settings"]["responses_template"]["tools"][0]["ranking_options"] = { + "weights": {"vector": 1.0, "keyword": 0.0} + } return pattern diff --git a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py index ea36859..77053d2 100644 --- a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py +++ b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py @@ -88,7 +88,6 @@ def test_adds_responses_template(self): assert len(rt["tools"]) == 1 assert rt["tools"][0]["type"] == "file_search" assert "test_collection_001" in rt["tools"][0]["vector_store_ids"] - assert rt["tools"][0]["ranking_options"]["max_num_results"] == 5 assert rt["tools"][0]["max_num_results"] == 5 assert rt["include"] == ["file_search_call.results"] @@ -112,7 +111,7 @@ def test_detected_language_injected(self): assert pattern["settings"]["generation"]["detected_language"] == lang def test_hybrid_rrf_ranking_options(self): - """Hybrid search with RRF ranker must merge impact_factor into ranking_options.""" + """Hybrid search with RRF ranker must set ranker and impact_factor in ranking_options.""" pattern = _make_pattern() pattern["settings"]["retrieval"]["search_mode"] = "hybrid" pattern["settings"]["retrieval"]["ranker_strategy"] = "rrf" @@ -121,11 +120,10 @@ def test_hybrid_rrf_ranking_options(self): build_pattern_json(pattern) ro = pattern["settings"]["responses_template"]["tools"][0]["ranking_options"] - assert ro["impact_factor"] == 60 - assert ro["max_num_results"] == 5 + assert ro == {"ranker": "rrf", "impact_factor": 60} def test_hybrid_weighted_ranking_options(self): - """Hybrid search with weighted ranker must merge 
alpha into ranking_options.""" + """Hybrid search with weighted ranker must set ranker and alpha in ranking_options.""" pattern = _make_pattern() pattern["settings"]["retrieval"]["search_mode"] = "hybrid" pattern["settings"]["retrieval"]["ranker_strategy"] = "weighted" @@ -134,20 +132,15 @@ def test_hybrid_weighted_ranking_options(self): build_pattern_json(pattern) ro = pattern["settings"]["responses_template"]["tools"][0]["ranking_options"] - assert ro["alpha"] == 0.7 - assert ro["max_num_results"] == 5 + assert ro == {"ranker": "weighted", "alpha": 0.7} def test_simple_retrieval_default_ranking_options(self): - """Simple retrieval must have default ranker and weights in ranking_options.""" + """Simple retrieval must have default weights in ranking_options.""" pattern = _make_pattern() build_pattern_json(pattern) ro = pattern["settings"]["responses_template"]["tools"][0]["ranking_options"] - assert ro == { - "max_num_results": 5, - "ranker": "auto", - "weights": {"vector": 1.0, "neural": 0.0, "keyword": 0.0}, - } + assert ro == {"weights": {"vector": 1.0, "keyword": 0.0}} assert pattern["settings"]["responses_template"]["tools"][0]["max_num_results"] == 5 def test_preserves_existing_pattern_fields(self): From 900371e225fa3ad7db8bb6720d46e9a5932f4890 Mon Sep 17 00:00:00 2001 From: Filip Komarzyniec Date: Wed, 24 Jun 2026 14:01:15 +0200 Subject: [PATCH 03/16] changed workaround for simulating semantic-only search in non hybrid search cases; related tests update Signed-off-by: Filip Komarzyniec Signed-off-by: Filip Komarzyniec rh-pre-commit.version: 2.3.2 rh-pre-commit.check-secrets: ENABLED Signed-off-by: Lukasz Cmielowski Assisted-by: Cursor --- ai4rag/components/assets_generator/pattern_builder.py | 4 +++- tests/unit/ai4rag/assets_generator/test_pattern_builder.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ai4rag/components/assets_generator/pattern_builder.py b/ai4rag/components/assets_generator/pattern_builder.py index 17542f9..6babc84 
100644 --- a/ai4rag/components/assets_generator/pattern_builder.py +++ b/ai4rag/components/assets_generator/pattern_builder.py @@ -66,7 +66,9 @@ def build_pattern_json( } else: pattern["settings"]["responses_template"]["tools"][0]["ranking_options"] = { - "weights": {"vector": 1.0, "keyword": 0.0} + # simulate semantic-only search + "ranker": "weighted", + "alpha": 1.0, } return pattern diff --git a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py index 77053d2..c4a0ee4 100644 --- a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py +++ b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py @@ -140,7 +140,7 @@ def test_simple_retrieval_default_ranking_options(self): build_pattern_json(pattern) ro = pattern["settings"]["responses_template"]["tools"][0]["ranking_options"] - assert ro == {"weights": {"vector": 1.0, "keyword": 0.0}} + assert ro == {"ranker": "weighted", "alpha": 1.0} assert pattern["settings"]["responses_template"]["tools"][0]["max_num_results"] == 5 def test_preserves_existing_pattern_fields(self): From be54757ac1ab1f24fa0c5ec36108d3ef506c96fd Mon Sep 17 00:00:00 2001 From: Jakub Walaszczyk Date: Wed, 24 Jun 2026 21:33:40 +0200 Subject: [PATCH 04/16] Downgrade docling-core version (#78) Signed-off-by: Jakub Walaszczyk Signed-off-by: Lukasz Cmielowski Assisted-by: Cursor --- pyproject.toml | 2 +- uv.lock | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 91585df..336a645 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "boto3>=1.28", "multiprocess>=0.70", "docling~=2.88.0", - "docling-core[chunking,chunking-openai]~=2.84.0", + "docling-core[chunking,chunking-openai]~=2.83.0", "langchain~=1.1.3", "langchain_chroma~=1.1.0", "langchain-text-splitters~=1.1.0", diff --git a/uv.lock b/uv.lock index ab9df90..1000093 100644 --- a/uv.lock +++ b/uv.lock @@ -95,7 +95,7 @@ 
requires-dist = [ { name = "black", marker = "extra == 'code-check'" }, { name = "boto3", specifier = ">=1.28" }, { name = "docling", specifier = "~=2.88.0" }, - { name = "docling-core", extras = ["chunking", "chunking-openai"], specifier = "~=2.84.0" }, + { name = "docling-core", extras = ["chunking", "chunking-openai"], specifier = "~=2.83.0" }, { name = "dotenv", marker = "extra == 'dev'" }, { name = "ipykernel", marker = "extra == 'dev'" }, { name = "isort", marker = "extra == 'code-check'" }, @@ -853,7 +853,7 @@ wheels = [ [[package]] name = "docling-core" -version = "2.84.0" +version = "2.83.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "defusedxml" }, @@ -869,9 +869,9 @@ dependencies = [ { name = "typer" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cb/bc/41d9d5a38502538a7cd3c7fb66a06eb676f8cc1d6d275d7fbe971352efd7/docling_core-2.84.0.tar.gz", hash = "sha256:331035d5032be683a6d66d476146652491d2c75e4562e04da3f1a8d989d74bfc", size = 317993, upload-time = "2026-06-23T09:44:42.339Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4f/04/0f62e8092dfeccf7d287b7898d8dcf793417e96f6976b41f5359bf143ebe/docling_core-2.83.1.tar.gz", hash = "sha256:e09ce91d18522bb161b45ca79b9bd2c94f0b38ca744a1f8605f3e1d9bff26902", size = 317894, upload-time = "2026-06-19T05:45:49.651Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/50/d4/a8742fdd4c8b782608005ffde77a24bbfd54602c4c49825bf04ea25363ab/docling_core-2.84.0-py3-none-any.whl", hash = "sha256:2aee881b94138234c703298066b5ed94999490323ae6a29cbe5110f9248b83dd", size = 257936, upload-time = "2026-06-23T09:44:40.506Z" }, + { url = "https://files.pythonhosted.org/packages/2b/7b/dfe63a732152097d6770a112389dbcc99f99928781d9f4ac72e29cadee91/docling_core-2.83.1-py3-none-any.whl", hash = "sha256:154db6d6be5ac0bc0affa272a39e81346dc953ea7ce5cd6f20809cc5272499c6", size = 257780, upload-time = "2026-06-19T05:45:47.866Z" }, ] 
[package.optional-dependencies] From 14bd1a92922d35fce74241707baa8a43d4e1b876 Mon Sep 17 00:00:00 2001 From: Jakub Walaszczyk Date: Wed, 24 Jun 2026 21:35:58 +0200 Subject: [PATCH 05/16] Release 0.8.1 Signed-off-by: Jakub Walaszczyk Assisted-by: Claude Code Signed-off-by: Lukasz Cmielowski --- ai4rag/__init__.py | 2 +- docs/about/changelog.md | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/ai4rag/__init__.py b/ai4rag/__init__.py index 8fe1577..1e6346f 100644 --- a/ai4rag/__init__.py +++ b/ai4rag/__init__.py @@ -5,7 +5,7 @@ import logging import os -__version__ = "0.8.0" +__version__ = "0.8.1" logger = logging.getLogger("ai4rag") logger.setLevel(os.getenv("LOG_LEVEL", "INFO")) diff --git a/docs/about/changelog.md b/docs/about/changelog.md index 29b34a3..5e72cb0 100644 --- a/docs/about/changelog.md +++ b/docs/about/changelog.md @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 --- +## [0.8.1](https://github.com/IBM/ai4rag/releases/tag/v0.8.1) + +### Changed +- Downgraded `docling-core` dependency from `~=2.84.0` to `~=2.83.0` to resolve compatibility issues + +--- + ## [0.8.0](https://github.com/IBM/ai4rag/releases/tag/v0.8.0) ### Added From e3368281ba8dd8f3b9461d655c8984278822b8ff Mon Sep 17 00:00:00 2001 From: Lukasz Cmielowski Date: Thu, 25 Jun 2026 13:52:24 +0200 Subject: [PATCH 06/16] feature: Update prompt templates (#75) Signed-off-by: Jakub Walaszczyk Signed-off-by: Lukasz Cmielowski Co-authored-by: Jakub Walaszczyk Assisted-by: Claude Code Signed-off-by: Lukasz Cmielowski --- .../assets_generator/pattern_builder.py | 6 - .../rag_templates_optimization.py | 42 --- .../optimization/search_space_preparation.py | 163 ----------- ai4rag/rag/foundation_models/base_model.py | 6 +- ai4rag/rag/foundation_models/ogx.py | 2 + ai4rag/rag/foundation_models/utils.py | 30 +- ai4rag/rag/template/simple_rag_template.py | 5 +- ai4rag/search_space/src/model_props.py | 145 ++++++---- dev_utils/mocks.py | 2 
+ .../assets_generator/test_pattern_builder.py | 13 - .../optimization/test_language_detection.py | 256 ------------------ .../optimization/test_rag_optimization.py | 105 ------- .../optimization/test_search_space_prep.py | 26 -- .../rag/foundation_models/test_base_model.py | 18 ++ .../ai4rag/rag/foundation_models/test_ogx.py | 2 +- .../rag/foundation_models/test_utils.py | 14 + .../rag/template/test_simple_rag_template.py | 18 ++ .../search_space/src/test_model_props.py | 91 +++++++ 18 files changed, 262 insertions(+), 682 deletions(-) delete mode 100644 tests/unit/ai4rag/components/optimization/test_language_detection.py create mode 100644 tests/unit/ai4rag/search_space/src/test_model_props.py diff --git a/ai4rag/components/assets_generator/pattern_builder.py b/ai4rag/components/assets_generator/pattern_builder.py index 6babc84..9d1c848 100644 --- a/ai4rag/components/assets_generator/pattern_builder.py +++ b/ai4rag/components/assets_generator/pattern_builder.py @@ -4,7 +4,6 @@ # ----------------------------------------------------------------------------- def build_pattern_json( pattern: dict, - detected_language: dict | None = None, ) -> dict: """Update pattern information with detected language and responses template. @@ -13,17 +12,12 @@ def build_pattern_json( pattern : dict A single evaluation result object carrying ``indexing_params``, ``rag_params``, ``pattern_name``, ``collection``, etc. - detected_language : dict | None, default=None - Language detection result (``{"code": "...", "name": "..."}``). Returns ------- dict Pattern definition suitable for JSON serialisation. 
""" - if detected_language: - pattern["settings"]["generation"]["detected_language"] = detected_language - pattern["settings"]["responses_template"] = { "model": pattern["settings"]["generation"]["model_id"], "stream": False, diff --git a/ai4rag/components/optimization/rag_templates_optimization.py b/ai4rag/components/optimization/rag_templates_optimization.py index a967147..cccf42c 100644 --- a/ai4rag/components/optimization/rag_templates_optimization.py +++ b/ai4rag/components/optimization/rag_templates_optimization.py @@ -131,8 +131,6 @@ def run_rag_optimization( # pylint: disable=too-many-locals,too-many-arguments, with open(search_space_report_path, "r", encoding="utf-8") as f: search_space_raw = yml.safe_load(f) - detected_language: dict[str, str] | None = search_space_raw.pop("detected_language", None) - search_space = AI4RAGSearchSpace( params=[Parameter(param, "C", values=values) for param, values in search_space_raw.items()] ) @@ -147,9 +145,6 @@ def run_rag_optimization( # pylint: disable=too-many-locals,too-many-arguments, benchmark_data = pd.read_json(Path(test_data_path)) - if detected_language: - _inject_language_instructions(search_space, detected_language) - rag_exp = AI4RAGExperiment( client=ogx_client, event_handler=event_handler, @@ -179,7 +174,6 @@ def run_rag_optimization( # pylint: disable=too-many-locals,too-many-arguments, pattern_data = build_pattern_json( pattern=pattern.get("payload"), - detected_language=detected_language, ) # Generate notebooks @@ -276,42 +270,6 @@ def _evaluation_result_fallback(eval_data_list: list, evaluation_result: Any) -> return out -def _inject_language_instructions( - search_space: AI4RAGSearchSpace, - detected_language: dict[str, str], -) -> None: - """Inject explicit language response instructions into foundation models. - - When a non-English language is detected, each foundation model's system - and user messages are augmented with an instruction to respond in that - language. 
- - Parameters - ---------- - search_space - The search space whose foundation models will be modified in-place. - detected_language - Detection result with ``"code"`` and ``"name"`` keys. - """ - lang_code = detected_language.get("code", "") - lang_name = detected_language.get("name", "") - if not lang_name or lang_code == "en": - return - - explicit_instruction = f"You MUST respond in {lang_name}." - for fm in search_space["foundation_model"].values: - existing_sys = fm.system_message_text - existing_usr = str(fm.user_message_text) - fm.system_message_text = f"{explicit_instruction} {existing_sys}" - fm.user_message_text = f"{existing_usr}{explicit_instruction}" - - _logger.info( - "Set explicit language instruction on %d foundation model(s): %s", - len(search_space["foundation_model"].values), - explicit_instruction, - ) - - def _validate_optimization_settings(optimization_settings: dict | None) -> dict: """Validate and normalize optimization settings. diff --git a/ai4rag/components/optimization/search_space_preparation.py b/ai4rag/components/optimization/search_space_preparation.py index 2561023..f30295e 100644 --- a/ai4rag/components/optimization/search_space_preparation.py +++ b/ai4rag/components/optimization/search_space_preparation.py @@ -30,29 +30,6 @@ _DEFAULT_SAMPLE_SIZE = 5 _DEFAULT_SEED = 17 -LANGUAGE_MAP: dict[str, str] = { - "ja": "Japanese", - "ko": "Korean", - "zh-cn": "Chinese", - "zh-tw": "Chinese", - "en": "English", - "de": "German", - "fr": "French", - "es": "Spanish", - "pt": "Portuguese", - "it": "Italian", - "ru": "Russian", - "ar": "Arabic", - "hi": "Hindi", - "th": "Thai", - "vi": "Vietnamese", - "pl": "Polish", - "nl": "Dutch", - "sv": "Swedish", - "cs": "Czech", - "tr": "Turkish", -} - def _represent_model_instance(dumper: yml.Dumper, model: BaseFoundationModel | BaseEmbeddingModel) -> yml.Node: """Instruct :mod:`yaml` on how to serialize model instances under a ``!Model`` tag. 
@@ -94,14 +71,10 @@ class SearchSpaceReport: model lists and non-model parameter ranges. selected_models : dict[str, list] Foundation and embedding model lists that survived pre-selection. - detected_language : dict[str, str] | None - Detected language code and name, or ``None`` when English or when - detection was not performed. """ search_space: dict[str, Any] selected_models: dict[str, list] - detected_language: dict[str, str] | None def save_yaml(self, path: str | Path) -> None: """Serialize the report to a YAML file. @@ -114,8 +87,6 @@ def save_yaml(self, path: str | Path) -> None: Destination file path. """ report = dict(self.search_space) - if self.detected_language: - report["detected_language"] = self.detected_language path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) @@ -199,10 +170,6 @@ def prepare_search_space_report( # pylint: disable=too-many-locals,too-many-arg # Load benchmark data and documents benchmark_df = pd.read_json(Path(test_data_path)) - detected_language = _detect_benchmark_language( - benchmark_df, llm_client=ogx_client, generation_models=generation_models - ) - benchmark_data = BenchmarkData(benchmark_df) documents = load_docling_documents(extracted_text_path) @@ -244,7 +211,6 @@ def prepare_search_space_report( # pylint: disable=too-many-locals,too-many-arg return SearchSpaceReport( search_space=verbose_repr, selected_models=selected_models, - detected_language=detected_language, ) @@ -257,132 +223,3 @@ def _validate_model_list(models: list[str] | None, name: str) -> None: for i, m in enumerate(models): if not m: raise TypeError(f"{name}[{i}] must be a non-empty string.") - - -def _detect_language_via_llm( # pylint: disable=too-many-locals - questions: list[str], - llm_client: OgxClient, - allowed_generation_models: list[str] | None = None, -) -> dict[str, str] | None: - """Detect the dominant language from sample questions using an LLM. 
- - Sends a small sample of questions to a generation model registered in OGX - and asks it to return the ISO 639-1 code. Models listed in - *allowed_generation_models* are preferred when available. - - Parameters - ---------- - questions - Raw question texts to classify. Only the first five are sent to the - model. - llm_client - An authenticated :class:`OgxClient` instance. - allowed_generation_models - Optional whitelist of model identifiers to prefer. - - Returns - ------- - dict[str, str] | None - A dictionary with ``code`` and ``name`` keys when a non-English - language is detected, or ``None`` for English / on failure. - """ - sample_text = "\n".join(f"- {q}" for q in questions[:5]) - valid_codes = ", ".join(sorted(LANGUAGE_MAP.keys())) - - try: - models_response = llm_client.models.list() - models_list = models_response.data if hasattr(models_response, "data") else list(models_response) - registered_ids = {(m.identifier if hasattr(m, "identifier") else str(m.id)) for m in models_list} - - model_id: str | None = None - if allowed_generation_models: - for gm in allowed_generation_models: - if gm in registered_ids: - model_id = gm - break - if not model_id: - for m in models_list: - if hasattr(m, "model_type") and getattr(m, "model_type", "") == "llm": - model_id = m.identifier if hasattr(m, "identifier") else str(m.id) - break - if not model_id: - _logger.warning("No models available for LLM language detection.") - return None - - response = llm_client.chat.completions.create( - model=model_id, - messages=[ - { - "role": "system", - "content": ( - "You are a language detection assistant. " - "Given text samples, respond with ONLY the ISO 639-1 language code " - f"(one of: {valid_codes}). " - "Nothing else — just the code." 
- ), - }, - { - "role": "user", - "content": f"What language are these questions written in?\n{sample_text}", - }, - ], - max_completion_tokens=10, - temperature=0.0, - ) - if not response.choices: - _logger.warning("LLM returned empty choices for language detection.") - return None - raw = response.choices[0].message.content.strip().lower().replace('"', "").replace("'", "") - detected_code = raw.split()[0] if raw else None - - if not detected_code: - return None - - name = LANGUAGE_MAP.get(detected_code) - if not name: - _logger.warning("LLM returned unsupported language code: %s", detected_code) - return None - - _logger.info("Language detected via LLM: %s (%s)", detected_code, name) - return {"code": detected_code, "name": name} - - except Exception as exc: - _logger.warning("LLM language detection failed: %s", exc) - return None - - -def _detect_benchmark_language( - benchmark_df: pd.DataFrame, - llm_client: OgxClient, - generation_models: list[str] | None = None, - sample_size: int = 10, -) -> dict[str, str] | None: - """Detect the dominant language from benchmark question data. - - Extracts up to *sample_size* questions from the ``question`` column and - delegates to :func:`detect_language_via_llm` for classification. - - Parameters - ---------- - benchmark_df - DataFrame with a ``question`` column. - llm_client - An authenticated :class:`OgxClient` instance. - generation_models - Optional whitelist of model identifiers passed through to the LLM - detection step. - sample_size - Maximum number of questions to sample. - - Returns - ------- - dict[str, str] | None - A dictionary with ``code`` and ``name`` keys when a non-English - language is detected, or ``None`` for English / on failure. 
- """ - questions = benchmark_df["question"].dropna().astype(str).tolist() - if not questions: - return None - - sample = questions[:sample_size] - return _detect_language_via_llm(sample, llm_client, allowed_generation_models=generation_models) diff --git a/ai4rag/rag/foundation_models/base_model.py b/ai4rag/rag/foundation_models/base_model.py index 50c83ce..3e381c8 100644 --- a/ai4rag/rag/foundation_models/base_model.py +++ b/ai4rag/rag/foundation_models/base_model.py @@ -37,13 +37,17 @@ def __init__( system_message_text: str | None = None, user_message_text: str | None = None, context_template_text: str | None = None, + language_autodetect: bool = False, ): self.client = client self.model_id = model_id self.params = params + self.language_autodetect = language_autodetect self.system_message_text = system_message_text or get_system_message_text(model_name=model_id) self.user_message_text = ( - user_message_text if user_message_text is not None else get_user_message_text(model_name=model_id) + user_message_text + if user_message_text is not None + else get_user_message_text(model_name=model_id, language_autodetect=language_autodetect) ) self.context_template_text = ( context_template_text diff --git a/ai4rag/rag/foundation_models/ogx.py b/ai4rag/rag/foundation_models/ogx.py index 51945d8..a04db14 100644 --- a/ai4rag/rag/foundation_models/ogx.py +++ b/ai4rag/rag/foundation_models/ogx.py @@ -32,6 +32,7 @@ def __init__( system_message_text: str | None = None, user_message_text: str | None = None, context_template_text: str | None = None, + language_autodetect: bool = False, ): super().__init__( @@ -41,6 +42,7 @@ def __init__( system_message_text=system_message_text, user_message_text=user_message_text, context_template_text=context_template_text, + language_autodetect=language_autodetect, ) @property diff --git a/ai4rag/rag/foundation_models/utils.py b/ai4rag/rag/foundation_models/utils.py index 1aa0b9b..a19f272 100644 --- a/ai4rag/rag/foundation_models/utils.py +++ 
b/ai4rag/rag/foundation_models/utils.py @@ -7,6 +7,7 @@ from ai4rag.search_space.src.model_props import ( CONTEXT_TEXT_PLACEHOLDER, + DOCUMENT_NUMBER_PLACEHOLDER, QUESTION_PLACEHOLDER, REFERENCE_DOCUMENTS_PLACEHOLDER, ) @@ -29,11 +30,12 @@ def __init__( super().__init__() self.template_name = template_name - self._required_placeholders: tuple[str, ...] = ( - (CONTEXT_TEXT_PLACEHOLDER,) - if template_name == "context_template_text" - else (QUESTION_PLACEHOLDER, REFERENCE_DOCUMENTS_PLACEHOLDER) - ) + if template_name == "context_template_text": + self._required_placeholders: tuple[str, ...] = (CONTEXT_TEXT_PLACEHOLDER,) + self._optional_placeholders: tuple[str, ...] = (DOCUMENT_NUMBER_PLACEHOLDER,) + else: + self._required_placeholders = (QUESTION_PLACEHOLDER, REFERENCE_DOCUMENTS_PLACEHOLDER) + self._optional_placeholders = () def validate(self, _: object, value: T) -> T: """ @@ -63,18 +65,20 @@ def validate(self, _: object, value: T) -> T: raise TypeError(f"Expected {value!r} to be a str or None.") placeholders_count = 0 + allowed = self._required_placeholders + self._optional_placeholders for _, field_name, _, _ in Formatter().parse(value): if field_name is None: # when there is text NOT followed by a placeholder template continue - if field_name not in self._required_placeholders: + if field_name not in allowed: raise ConstraintsValidationError( f"Custom {field_name.split('_')[0]} template text got unexpected placeholder `{field_name}`, " - f"valid placeholders are `{self._required_placeholders}`." + f"valid placeholders are `{allowed}`." 
) - placeholders_count += 1 + if field_name in self._required_placeholders: + placeholders_count += 1 if placeholders_count != len(self._required_placeholders): raise ConstraintsValidationError( @@ -114,24 +118,28 @@ def _validate_prompt_templates_placeholders( """ if template_name == "context_template_text": required_placeholders = (CONTEXT_TEXT_PLACEHOLDER,) + optional_placeholders = (DOCUMENT_NUMBER_PLACEHOLDER,) elif template_name == "user_message_text": required_placeholders = (QUESTION_PLACEHOLDER, REFERENCE_DOCUMENTS_PLACEHOLDER) + optional_placeholders = () else: raise ValueError(f"Cannot validate presence of expected template placeholders on field: {template_name}") placeholders_count = 0 + allowed = required_placeholders + optional_placeholders for _, field_name, _, _ in Formatter().parse(template_str): if field_name is None: # when there is text NOT followed by a placeholder template continue - if field_name not in required_placeholders: + if field_name not in allowed: raise ValueError( f"Custom {field_name.split('_')[0]} template text got unexpected placeholder `{field_name}`, " - f"valid placeholders are `{required_placeholders}`." + f"valid placeholders are `{allowed}`." 
) - placeholders_count += 1 + if field_name in required_placeholders: + placeholders_count += 1 if placeholders_count != len(required_placeholders): raise ValueError( diff --git a/ai4rag/rag/template/simple_rag_template.py b/ai4rag/rag/template/simple_rag_template.py index cd5b0ac..df73018 100644 --- a/ai4rag/rag/template/simple_rag_template.py +++ b/ai4rag/rag/template/simple_rag_template.py @@ -95,8 +95,9 @@ def generate(self, question: str, **kwargs) -> dict[str, Any]: """ reference_documents = self.retriever.retrieve(question, **kwargs) - context = "\n".join( - [self.foundation_model.context_template_text.format(document=chunk.text) for chunk in reference_documents] + context = "\n\n".join( + self.foundation_model.context_template_text.format(document=chunk.text, doc_number=doc_number) + for doc_number, chunk in enumerate(reference_documents, start=1) ) user_message = self.foundation_model.user_message_text.format( diff --git a/ai4rag/search_space/src/model_props.py b/ai4rag/search_space/src/model_props.py index e853ed0..1a07e10 100644 --- a/ai4rag/search_space/src/model_props.py +++ b/ai4rag/search_space/src/model_props.py @@ -9,6 +9,7 @@ "QUESTION_PLACEHOLDER", "REFERENCE_DOCUMENTS_PLACEHOLDER", "CONTEXT_TEXT_PLACEHOLDER", + "DOCUMENT_NUMBER_PLACEHOLDER", "MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER", ] @@ -16,43 +17,65 @@ QUESTION_PLACEHOLDER = "question" REFERENCE_DOCUMENTS_PLACEHOLDER = "reference_documents" CONTEXT_TEXT_PLACEHOLDER = "document" +DOCUMENT_NUMBER_PLACEHOLDER = "doc_number" MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER = "multilingual_support" -# A mapping from model name into their corresponding prompt templates. -# The parameters for the prompt templates are QUESTION_PLACEHOLDER and REFERENCE_DOCUMENTS_PLACEHOLDER - _MULTILINGUAL_SUPPORT_ENABLED_PROMPT = ( - "Respond exclusively in the language of the question, " - "regardless of any other language used in the provided context. 
" - "Ensure that your entire response is in the same language as the question." + "You MUST write your entire answer in the same language as the question. " + "Do NOT respond in any other language, even if the documents use a different language. " + "Every word of your answer must match the question's language." ) _MULTILINGUAL_SUPPORT_DISABLED_PROMPT = ( - "Respond exclusively in English, " - "regardless of the language of the question or any other language used in the provided context. " - "Ensure that your entire response is in English only." + "You MUST write your entire answer in English only. " + "Do NOT use any other language, even if the question or documents are in another language. " + "Every word of your answer must be in English." +) + + +_RAG_GROUNDING_INSTRUCTION = ( + "Answer ONLY using information from the documents below. " + "Do not use outside knowledge. " + "If the documents do not contain the answer, say you do not have enough information." +) + + +_RAG_CITATION_INSTRUCTION = ( + "You MUST cite sources using [1], [2], etc. matching the document numbers for every factual claim." ) +_RAG_ANSWER_LENGTH_GUIDANCE = "max 150 words" + + +_RAG_ANSWER_PROMPT_LINE = f"Answer ({_RAG_ANSWER_LENGTH_GUIDANCE}, with citations):\n" + + +_RAG_SYSTEM_PREFIX = "You are a retrieval-augmented assistant. Answer using ONLY the provided documents. " + + +_DEFAULT_NUMBERED_CONTEXT_TEMPLATE = f"Document {{{DOCUMENT_NUMBER_PLACEHOLDER}}}:\n{{{CONTEXT_TEXT_PLACEHOLDER}}}\n" + + _DEFAULT_SYSTEM_MESSAGE_TEXT = ( - "Please answer the question I provide in the Question section below, " - "based solely on the information I provide in the Context section. " - "If the question is unanswerable, please say you cannot answer." + f"{_RAG_SYSTEM_PREFIX}" "If the question is unanswerable from the documents, say you cannot answer." ) _DEFAULT_USER_MESSAGE_TEXT = ( - f"\n\nContext:\n{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}:\n\n" - f"Question: {{{QUESTION_PLACEHOLDER}}}. 
\n" - "Again, please answer the question based on the context provided only. If the context is not related to " - "the question, just say you cannot answer. " - f"{{{MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER}}}" + f"{_RAG_GROUNDING_INSTRUCTION}\n" + f"{_RAG_CITATION_INSTRUCTION}\n\n" + f"Documents:\n{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}\n\n" + f"Question: {{{QUESTION_PLACEHOLDER}}}\n\n" + f"{_RAG_ANSWER_PROMPT_LINE}" + f"{{{MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER}}}\n" ) _DEFAULT_GRANITE_SYSTEM_MESSAGE_TEXT = ( + f"{_RAG_SYSTEM_PREFIX}" "You are Granite Chat, an AI language model developed by IBM. " "You are a cautious assistant. You carefully follow instructions. " "You are helpful and harmless and you follow ethical guidelines and promote positive behaviour." @@ -60,40 +83,39 @@ _DEFAULT_GRANITE_USER_MESSAGE_TEXT = ( - "You are an AI language model designed to function as a specialized Retrieval Augmented Generation (RAG) " - "assistant. When generating responses, prioritize correctness, i.e., ensure that your response is grounded in " - "context and user query. Always make sure that your response is relevant to the question. " - "\n" - "Answer Length: detailed" - "\n" - f"{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}" - "\n" - f"{{{MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER}}}" - "\n" - f"{{{QUESTION_PLACEHOLDER}}} " - "\n" - "\n" + f"{_RAG_GROUNDING_INSTRUCTION}\n" + f"{_RAG_CITATION_INSTRUCTION}\n\n" + "You are a specialized Retrieval Augmented Generation (RAG) assistant. " + "Prioritize correctness and ensure your response is grounded in the documents.\n\n" + f"Documents:\n{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}\n\n" + f"Question: {{{QUESTION_PLACEHOLDER}}}\n\n" + f"{_RAG_ANSWER_PROMPT_LINE}" + f"{{{MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER}}}\n" ) _DEFAULT_LLAMA_SYSTEM_MESSAGE_TEXT = ( + f"{_RAG_SYSTEM_PREFIX}" "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. 
" "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. " "Please ensure that your responses are socially unbiased and positive in nature.\n" "If a question does not make any sense, or is not factually coherent, explain why instead of answering " - "something not correct. If you don’t know the answer to a question, please don’t share false information.\n" + "something not correct. If you don't know the answer to a question, please don't share false information.\n" ) _DEFAULT_LLAMA_USER_MESSAGE_TEXT = ( - f"{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}\n" - f"[conversation]: {{{QUESTION_PLACEHOLDER}}}. Answer with no more than 150 words. If you cannot base your " - "answer on the given document, please state that you do not have an answer. " + f"{_RAG_GROUNDING_INSTRUCTION}\n" + f"{_RAG_CITATION_INSTRUCTION}\n\n" + f"Documents:\n{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}\n\n" + f"Question: {{{QUESTION_PLACEHOLDER}}}\n\n" + f"{_RAG_ANSWER_PROMPT_LINE}" f"{{{MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER}}}\n" ) _DEFAULT_MISTRAL_SYSTEM_MESSAGE_TEXT = ( + f"{_RAG_SYSTEM_PREFIX}" "You are a helpful, respectful and honest assistant. " "Always answer as helpfully as possible, while being safe. " "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. " @@ -104,32 +126,34 @@ _DEFAULT_MISTRAL_USER_MESSAGE_TEXT = ( - "Generate the next agent response by answering the question. You are provided several documents with titles. " - "If the answer comes from different documents please mention all possibilities and use the titles of documents " - "to separate between topics or domains. If you cannot base your answer on the given documents, " - f"please state that you do not have an answer. 
" - f"{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}\n\n" - f"{{{MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER}}}\n\n" - f"{{{QUESTION_PLACEHOLDER}}}" + f"{_RAG_GROUNDING_INSTRUCTION}\n" + f"{_RAG_CITATION_INSTRUCTION}\n\n" + f"Documents:\n{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}\n\n" + f"Question: {{{QUESTION_PLACEHOLDER}}}\n\n" + f"{_RAG_ANSWER_PROMPT_LINE}" + f"{{{MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER}}}\n" ) _DEFAULT_OPENAI_SYSTEM_MESSAGE_TEXT = ( - "You are a AI language model designed to function as a specialized Retrieval Augmented Generation (RAG) assistant. " + f"{_RAG_SYSTEM_PREFIX}" "When generating responses, prioritize correctness, i.e., ensure that your response is correct given the context " "and user query, and that it is grounded in the context. " "Furthermore, make sure that the response is supported by the given document or context. " "When the question cannot be answered using the context or document, output the following response: " "'I am sorry, I do not have the information you are looking for in my knowledge base.'. " "Always make sure that your response is relevant to the question. If an explanation is needed, " - "first provide the explanation or reasoning, and then give the final answer.\nAnswer Length: concise.\n\n" + "first provide the explanation or reasoning, and then give the final answer.\n\n" ) _DEFAULT_OPENAI_USER_MESSAGE_TEXT = ( - f"[Document]\n{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}\n[End]\n" - f"{{{QUESTION_PLACEHOLDER}}}. \n" - f"{{{MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER}}}" + f"{_RAG_GROUNDING_INSTRUCTION}\n" + f"{_RAG_CITATION_INSTRUCTION}\n\n" + f"Documents:\n{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}\n\n" + f"Question: {{{QUESTION_PLACEHOLDER}}}\n\n" + f"{_RAG_ANSWER_PROMPT_LINE}" + f"{{{MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER}}}\n" ) @@ -161,12 +185,11 @@ } -# A mapping from model names into their corresponding context template texts. 
These templates describe how each -# retrieved context is to be wrapped, before being integrated into a full RAG prompt text. -# The parameter for the context template text is CONTEXT_TEXT_PLACEHOLDER -_DEFAULT_GRANITE_CONTEXT_TEMPLATE = f"[Document]\n{{{CONTEXT_TEXT_PLACEHOLDER}}}\n[End]" -_DEFAULT_LLAMA_CONTEXT_TEMPLATE = f"[document]: {{{CONTEXT_TEXT_PLACEHOLDER}}}\n" -_DEFAULT_CONTEXT_TEMPLATE = f"{{{CONTEXT_TEXT_PLACEHOLDER}}}" +_DEFAULT_GRANITE_CONTEXT_TEMPLATE = _DEFAULT_NUMBERED_CONTEXT_TEMPLATE +_DEFAULT_LLAMA_CONTEXT_TEMPLATE = _DEFAULT_NUMBERED_CONTEXT_TEMPLATE +_DEFAULT_MISTRAL_CONTEXT_TEMPLATE = _DEFAULT_NUMBERED_CONTEXT_TEMPLATE +_DEFAULT_OPENAI_CONTEXT_TEMPLATE = _DEFAULT_NUMBERED_CONTEXT_TEMPLATE +_DEFAULT_CONTEXT_TEMPLATE = _DEFAULT_NUMBERED_CONTEXT_TEMPLATE _model_name_to_context_template_text = { "meta-llama/llama-3-1-70b-instruct": _DEFAULT_LLAMA_CONTEXT_TEMPLATE, @@ -175,7 +198,10 @@ "meta-llama/llama-4-maverick-17b-128e-instruct-fp8": _DEFAULT_LLAMA_CONTEXT_TEMPLATE, "ibm/granite-3-8b-instruct": _DEFAULT_GRANITE_CONTEXT_TEMPLATE, "ibm/granite-3-3-8b-instruct": _DEFAULT_GRANITE_CONTEXT_TEMPLATE, - "openai/gpt-oss-120b": _DEFAULT_CONTEXT_TEMPLATE, + "mistralai/mistral-small-3-1-24b-instruct-2503": _DEFAULT_MISTRAL_CONTEXT_TEMPLATE, + "mistralai/mistral-medium-2505": _DEFAULT_MISTRAL_CONTEXT_TEMPLATE, + "mistralai/mistral-large": _DEFAULT_MISTRAL_CONTEXT_TEMPLATE, + "openai/gpt-oss-120b": _DEFAULT_OPENAI_CONTEXT_TEMPLATE, } @@ -183,8 +209,8 @@ def get_context_template_text(model_name: str) -> str: """ Get a model-specific context template text. - The context template text is a template with one placeholder: "context_text". - This field should be populated before use within a RAG prompt. + The context template text is a template with placeholders ``document`` and, + optionally, ``doc_number``. 
Parameters ---------- @@ -203,6 +229,8 @@ def get_context_template_text(model_name: str) -> str: context_template = _DEFAULT_GRANITE_CONTEXT_TEMPLATE elif "llama" in model_name: context_template = _DEFAULT_LLAMA_CONTEXT_TEMPLATE + elif "mistral" in model_name: + context_template = _DEFAULT_MISTRAL_CONTEXT_TEMPLATE else: context_template = _DEFAULT_CONTEXT_TEMPLATE @@ -233,13 +261,15 @@ def get_system_message_text(model_name: str) -> str: system_message_text = _DEFAULT_LLAMA_SYSTEM_MESSAGE_TEXT elif "mistral" in model_name: system_message_text = _DEFAULT_MISTRAL_SYSTEM_MESSAGE_TEXT + elif "openai" in model_name or "gpt" in model_name: + system_message_text = _DEFAULT_OPENAI_SYSTEM_MESSAGE_TEXT else: system_message_text = _DEFAULT_SYSTEM_MESSAGE_TEXT return system_message_text -def get_user_message_text(model_name: str, language_autodetect: bool = True) -> str: +def get_user_message_text(model_name: str, language_autodetect: bool = False) -> str: """ Get a model-specific prompt text. @@ -253,6 +283,7 @@ def get_user_message_text(model_name: str, language_autodetect: bool = True) -> language_autodetect : bool If True, language of the question will be automatically detected. + Defaults to False (English-only responses) for stronger faithfulness on English benchmarks. 
Returns ------- @@ -268,6 +299,8 @@ def get_user_message_text(model_name: str, language_autodetect: bool = True) -> user_message_text = _DEFAULT_LLAMA_USER_MESSAGE_TEXT elif "mistral" in model_name: user_message_text = _DEFAULT_MISTRAL_USER_MESSAGE_TEXT + elif "openai" in model_name or "gpt" in model_name: + user_message_text = _DEFAULT_OPENAI_USER_MESSAGE_TEXT else: user_message_text = _DEFAULT_USER_MESSAGE_TEXT diff --git a/dev_utils/mocks.py b/dev_utils/mocks.py index c46543c..5a6986e 100644 --- a/dev_utils/mocks.py +++ b/dev_utils/mocks.py @@ -20,6 +20,7 @@ def __init__( system_message_text: str | None = None, user_message_text: str | None = None, context_template_text: str | None = None, + language_autodetect: bool = False, ): super().__init__( client=client, @@ -28,6 +29,7 @@ def __init__( system_message_text=system_message_text, user_message_text=user_message_text, context_template_text=context_template_text, + language_autodetect=language_autodetect, ) def chat(self, messages: list[MessageTyped]) -> list[MessageTyped]: diff --git a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py index c4a0ee4..3752333 100644 --- a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py +++ b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py @@ -97,19 +97,6 @@ def test_returns_same_dict(self): result = build_pattern_json(pattern) assert result is pattern - def test_no_detected_language_by_default(self): - """When detected_language is None, no detected_language key appears in generation.""" - pattern = _make_pattern() - build_pattern_json(pattern) - assert "detected_language" not in pattern["settings"]["generation"] - - def test_detected_language_injected(self): - """Non-English language detection must inject detected_language into generation.""" - pattern = _make_pattern() - lang = {"code": "de", "name": "German"} - build_pattern_json(pattern, detected_language=lang) - assert 
pattern["settings"]["generation"]["detected_language"] == lang - def test_hybrid_rrf_ranking_options(self): """Hybrid search with RRF ranker must set ranker and impact_factor in ranking_options.""" pattern = _make_pattern() diff --git a/tests/unit/ai4rag/components/optimization/test_language_detection.py b/tests/unit/ai4rag/components/optimization/test_language_detection.py deleted file mode 100644 index 5447736..0000000 --- a/tests/unit/ai4rag/components/optimization/test_language_detection.py +++ /dev/null @@ -1,256 +0,0 @@ -# ----------------------------------------------------------------------------- -# Copyright IBM Corp. 2026 -# SPDX-License-Identifier: Apache-2.0 -# ----------------------------------------------------------------------------- -from __future__ import annotations - -from unittest.mock import MagicMock - -import pandas as pd -import pytest - -from ai4rag.components.optimization.search_space_preparation import ( - LANGUAGE_MAP, -) -from ai4rag.components.optimization.search_space_preparation import ( - _detect_benchmark_language as detect_benchmark_language, -) -from ai4rag.components.optimization.search_space_preparation import _detect_language_via_llm as detect_language_via_llm - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture() -def mock_ogx_client() -> MagicMock: - """Return a MagicMock that behaves like an OgxClient with one LLM model.""" - mock_model = MagicMock() - mock_model.identifier = "test-model" - mock_model.model_type = "llm" - - mock_models_response = MagicMock() - mock_models_response.data = [mock_model] - - mock_choice = MagicMock() - mock_choice.message.content = "ja" - mock_response = MagicMock() - mock_response.choices = [mock_choice] - - mock_client = MagicMock() - mock_client.models.list.return_value = mock_models_response - mock_client.chat.completions.create.return_value = 
mock_response - return mock_client - - -@pytest.fixture() -def sample_questions() -> list[str]: - """Return a short list of sample questions.""" - return [ - "東京の天気はどうですか?", - "日本の首都はどこですか?", - "富士山の高さは?", - ] - - -# --------------------------------------------------------------------------- -# LANGUAGE_MAP -# --------------------------------------------------------------------------- - - -class TestLanguageMap: - """Verify the static LANGUAGE_MAP contents.""" - - def test_known_codes_present(self): - """Well-known ISO 639-1 codes must resolve to their language names.""" - assert LANGUAGE_MAP["ja"] == "Japanese" - assert LANGUAGE_MAP["en"] == "English" - assert LANGUAGE_MAP["pl"] == "Polish" - assert LANGUAGE_MAP["de"] == "German" - assert LANGUAGE_MAP["fr"] == "French" - assert LANGUAGE_MAP["ko"] == "Korean" - - def test_chinese_variants(self): - """Both simplified and traditional Chinese codes must be present.""" - assert LANGUAGE_MAP["zh-cn"] == "Chinese" - assert LANGUAGE_MAP["zh-tw"] == "Chinese" - - def test_all_values_are_nonempty_strings(self): - """Every value in LANGUAGE_MAP must be a non-empty human-readable name.""" - for code, name in LANGUAGE_MAP.items(): - assert isinstance(code, str) and code, f"Invalid code: {code!r}" - assert isinstance(name, str) and name, f"Invalid name for {code!r}: {name!r}" - - -# --------------------------------------------------------------------------- -# detect_language_via_llm -# --------------------------------------------------------------------------- - - -class TestDetectLanguageViaLlm: - """Tests for the LLM-based language detection function.""" - - def test_detects_japanese(self, mock_ogx_client, sample_questions): - """When the LLM returns 'ja', the result must contain the correct code and name.""" - result = detect_language_via_llm(sample_questions, mock_ogx_client) - - assert result is not None - assert result == {"code": "ja", "name": "Japanese"} - mock_ogx_client.chat.completions.create.assert_called_once() - 
- def test_english_returns_none(self, mock_ogx_client, sample_questions): - """English is the default language, so detection must return None.""" - mock_ogx_client.chat.completions.create.return_value.choices[0].message.content = "en" - - result = detect_language_via_llm(sample_questions, mock_ogx_client) - - # English maps to a valid entry in LANGUAGE_MAP, so it returns the dict. - # The contract says "None for English" only at the detect_benchmark_language - # level. At this level the function returns the mapping when the code is - # valid, regardless of which language it is. - # Re-reading the source: the function returns {"code": ..., "name": ...} - # for ANY valid code, including English. - assert result == {"code": "en", "name": "English"} - - def test_api_failure_returns_none(self, mock_ogx_client, sample_questions): - """An exception from the OGX client must be swallowed, returning None.""" - mock_ogx_client.chat.completions.create.side_effect = RuntimeError("API unavailable") - - result = detect_language_via_llm(sample_questions, mock_ogx_client) - - assert result is None - - def test_unsupported_language_code_returns_none(self, mock_ogx_client, sample_questions): - """An ISO code not present in LANGUAGE_MAP must return None.""" - mock_ogx_client.chat.completions.create.return_value.choices[0].message.content = "xx" - - result = detect_language_via_llm(sample_questions, mock_ogx_client) - - assert result is None - - def test_no_models_available_returns_none(self, sample_questions): - """When no models are registered, the function must return None.""" - mock_model_response = MagicMock() - mock_model_response.data = [] - - mock_client = MagicMock() - mock_client.models.list.return_value = mock_model_response - - result = detect_language_via_llm(sample_questions, mock_client) - - assert result is None - mock_client.chat.completions.create.assert_not_called() - - def test_prefers_allowed_generation_model(self, sample_questions): - """When 
allowed_generation_models is set, the preferred model must be used.""" - preferred_model = MagicMock() - preferred_model.identifier = "preferred-llm" - preferred_model.model_type = "llm" - - other_model = MagicMock() - other_model.identifier = "other-llm" - other_model.model_type = "llm" - - mock_models_response = MagicMock() - mock_models_response.data = [other_model, preferred_model] - - mock_choice = MagicMock() - mock_choice.message.content = "ja" - mock_response = MagicMock() - mock_response.choices = [mock_choice] - - mock_client = MagicMock() - mock_client.models.list.return_value = mock_models_response - mock_client.chat.completions.create.return_value = mock_response - - detect_language_via_llm(sample_questions, mock_client, allowed_generation_models=["preferred-llm"]) - - call_kwargs = mock_client.chat.completions.create.call_args - assert call_kwargs[1]["model"] == "preferred-llm" or call_kwargs.kwargs["model"] == "preferred-llm" - - def test_samples_at_most_five_questions(self, mock_ogx_client): - """Only the first five questions should appear in the prompt.""" - many_questions = [f"Question {i}" for i in range(20)] - - detect_language_via_llm(many_questions, mock_ogx_client) - - call_kwargs = mock_ogx_client.chat.completions.create.call_args - user_content = call_kwargs[1]["messages"][1]["content"] - # The prompt enumerates "- Q" lines; at most 5 should appear. 
- assert user_content.count("- Question") == 5 - - def test_empty_llm_response_returns_none(self, mock_ogx_client, sample_questions): - """A blank response from the LLM must return None.""" - mock_ogx_client.chat.completions.create.return_value.choices[0].message.content = " " - - result = detect_language_via_llm(sample_questions, mock_ogx_client) - - assert result is None - - def test_models_list_failure_returns_none(self, sample_questions): - """An exception during models.list() must be swallowed.""" - mock_client = MagicMock() - mock_client.models.list.side_effect = ConnectionError("timeout") - - result = detect_language_via_llm(sample_questions, mock_client) - - assert result is None - - -# --------------------------------------------------------------------------- -# detect_benchmark_language -# --------------------------------------------------------------------------- - - -class TestDetectBenchmarkLanguage: - """Tests for the DataFrame-level language detection wrapper.""" - - def test_detects_language_from_dataframe(self, mock_ogx_client): - """A DataFrame with a 'question' column must yield detection results.""" - df = pd.DataFrame({"question": ["東京の天気は?", "富士山の高さは?", "日本の首都は?"]}) - - result = detect_benchmark_language(df, mock_ogx_client) - - assert result is not None - assert result["code"] == "ja" - mock_ogx_client.chat.completions.create.assert_called_once() - - def test_empty_dataframe_returns_none(self, mock_ogx_client): - """An empty DataFrame must short-circuit to None without calling the LLM.""" - df = pd.DataFrame({"question": pd.Series([], dtype=str)}) - - result = detect_benchmark_language(df, mock_ogx_client) - - assert result is None - mock_ogx_client.chat.completions.create.assert_not_called() - - def test_all_nan_questions_returns_none(self, mock_ogx_client): - """When every question value is NaN, the function must return None.""" - df = pd.DataFrame({"question": [None, None, None]}) - - result = detect_benchmark_language(df, 
mock_ogx_client) - - assert result is None - mock_ogx_client.chat.completions.create.assert_not_called() - - def test_respects_sample_size(self, mock_ogx_client): - """The sample_size parameter must cap the number of questions forwarded.""" - df = pd.DataFrame({"question": [f"Q{i}" for i in range(50)]}) - - detect_benchmark_language(df, mock_ogx_client, sample_size=3) - - call_kwargs = mock_ogx_client.chat.completions.create.call_args - user_content = call_kwargs[1]["messages"][1]["content"] - # detect_language_via_llm further caps to 5, but sample_size=3 means - # only 3 questions are passed in. - assert user_content.count("- Q") == 3 - - def test_passes_generation_models_through(self, mock_ogx_client): - """The generation_models parameter must reach detect_language_via_llm.""" - df = pd.DataFrame({"question": ["Hello?"]}) - - detect_benchmark_language(df, mock_ogx_client, generation_models=["custom-model"]) - - # The function should still call the LLM; the model selection logic - # inside detect_language_via_llm handles the allowed list. 
- mock_ogx_client.chat.completions.create.assert_called_once() diff --git a/tests/unit/ai4rag/components/optimization/test_rag_optimization.py b/tests/unit/ai4rag/components/optimization/test_rag_optimization.py index 6e082ab..5ee6efe 100644 --- a/tests/unit/ai4rag/components/optimization/test_rag_optimization.py +++ b/tests/unit/ai4rag/components/optimization/test_rag_optimization.py @@ -12,12 +12,9 @@ DEFAULT_MAX_RAG_PATTERNS, MIN_MAX_RAG_PATTERNS_RANGE, SUPPORTED_OPTIMIZATION_METRICS, - _inject_language_instructions, _validate_optimization_settings, run_rag_optimization, ) -from ai4rag.search_space.src.parameter import Parameter -from ai4rag.search_space.src.search_space import AI4RAGSearchSpace # --------------------------------------------------------------------------- # Fixtures @@ -30,15 +27,6 @@ def mock_ogx_client() -> MagicMock: return MagicMock() -@pytest.fixture() -def foundation_model_stub() -> MagicMock: - """Return a MagicMock that acts as a foundation model with message attributes.""" - fm = MagicMock() - fm.system_message_text = "You are a helpful assistant." 
- fm.user_message_text = "Answer: {question}" - return fm - - # --------------------------------------------------------------------------- # _validate_optimization_settings # --------------------------------------------------------------------------- @@ -130,99 +118,6 @@ def test_extra_keys_preserved(self): assert result["metric"] == "faithfulness" -# --------------------------------------------------------------------------- -# _inject_language_instructions -# --------------------------------------------------------------------------- - - -class TestInjectLanguageInstructions: - """Tests for injecting language-specific instructions into foundation models.""" - - @staticmethod - def _make_search_space(foundation_models: list) -> AI4RAGSearchSpace: - """Build a minimal AI4RAGSearchSpace with required parameters.""" - dummy_embedding = MagicMock() - dummy_embedding.context_length = 8192 - return AI4RAGSearchSpace( - params=[ - Parameter("foundation_model", "C", values=foundation_models), - Parameter("embedding_model", "C", values=[dummy_embedding]), - ] - ) - - def test_injects_language_into_system_message(self, foundation_model_stub): - """The system message must be prefixed with the language instruction.""" - search_space = self._make_search_space([foundation_model_stub]) - - _inject_language_instructions(search_space, {"code": "ja", "name": "Japanese"}) - - fm = search_space["foundation_model"].values[0] - assert fm.system_message_text.startswith("You MUST respond in Japanese.") - - def test_injects_language_into_user_message(self, foundation_model_stub): - """The user message must be appended with the language instruction.""" - search_space = self._make_search_space([foundation_model_stub]) - - _inject_language_instructions(search_space, {"code": "ja", "name": "Japanese"}) - - fm = search_space["foundation_model"].values[0] - assert fm.user_message_text.endswith("You MUST respond in Japanese.") - - def test_english_code_skips_injection(self, 
foundation_model_stub): - """English language detection must not modify the messages.""" - original_sys = foundation_model_stub.system_message_text - original_usr = foundation_model_stub.user_message_text - - search_space = self._make_search_space([foundation_model_stub]) - - _inject_language_instructions(search_space, {"code": "en", "name": "English"}) - - fm = search_space["foundation_model"].values[0] - assert fm.system_message_text == original_sys - assert fm.user_message_text == original_usr - - def test_empty_name_skips_injection(self, foundation_model_stub): - """An empty language name must not inject any instruction.""" - original_sys = foundation_model_stub.system_message_text - - search_space = self._make_search_space([foundation_model_stub]) - - _inject_language_instructions(search_space, {"code": "ja", "name": ""}) - - fm = search_space["foundation_model"].values[0] - assert fm.system_message_text == original_sys - - def test_multiple_foundation_models(self): - """All foundation models in the search space must be updated.""" - fm1 = MagicMock() - fm1.system_message_text = "sys1" - fm1.user_message_text = "usr1" - - fm2 = MagicMock() - fm2.system_message_text = "sys2" - fm2.user_message_text = "usr2" - - search_space = self._make_search_space([fm1, fm2]) - - _inject_language_instructions(search_space, {"code": "ko", "name": "Korean"}) - - for fm in search_space["foundation_model"].values: - assert "You MUST respond in Korean." in fm.system_message_text - assert "You MUST respond in Korean." 
in fm.user_message_text - - def test_none_system_message_gets_instruction_prepended(self): - """A model with None system_message_text must still get the language instruction.""" - fm = MagicMock() - fm.system_message_text = None - fm.user_message_text = None - - search_space = self._make_search_space([fm]) - - _inject_language_instructions(search_space, {"code": "de", "name": "German"}) - - assert fm.system_message_text.startswith("You MUST respond in German.") - - # --------------------------------------------------------------------------- # run_rag_optimization -- input validation only # --------------------------------------------------------------------------- diff --git a/tests/unit/ai4rag/components/optimization/test_search_space_prep.py b/tests/unit/ai4rag/components/optimization/test_search_space_prep.py index 752f05a..76d8e4e 100644 --- a/tests/unit/ai4rag/components/optimization/test_search_space_prep.py +++ b/tests/unit/ai4rag/components/optimization/test_search_space_prep.py @@ -34,7 +34,6 @@ def simple_report() -> SearchSpaceReport: "foundation_model": ["model-a"], "embedding_model": ["emb-a"], }, - detected_language={"code": "ja", "name": "Japanese"}, ) @@ -94,31 +93,6 @@ def test_save_yaml_creates_file(self, simple_report, tmp_path: Path): data = yml.safe_load(out_file.read_text()) assert isinstance(data, dict) - def test_save_yaml_includes_detected_language(self, simple_report, tmp_path: Path): - """When detected_language is set, it must appear in the YAML output.""" - import yaml as yml - - out_file = tmp_path / "report.yaml" - simple_report.save_yaml(out_file) - - data = yml.safe_load(out_file.read_text()) - assert data["detected_language"] == {"code": "ja", "name": "Japanese"} - - def test_save_yaml_omits_language_when_none(self, tmp_path: Path): - """When detected_language is None, the key must not appear in output.""" - import yaml as yml - - report = SearchSpaceReport( - search_space={"chunk_size": [256]}, - 
selected_models={"foundation_model": []}, - detected_language=None, - ) - out_file = tmp_path / "report.yaml" - report.save_yaml(out_file) - - data = yml.safe_load(out_file.read_text()) - assert "detected_language" not in data - def test_save_yaml_creates_parent_directories(self, simple_report, tmp_path: Path): """save_yaml must create intermediate directories if they do not exist.""" out_file = tmp_path / "nested" / "dir" / "report.yaml" diff --git a/tests/unit/ai4rag/rag/foundation_models/test_base_model.py b/tests/unit/ai4rag/rag/foundation_models/test_base_model.py index 4a40739..0e58efd 100644 --- a/tests/unit/ai4rag/rag/foundation_models/test_base_model.py +++ b/tests/unit/ai4rag/rag/foundation_models/test_base_model.py @@ -54,6 +54,24 @@ def test_init(self, mock_client, model_params): assert model.client == mock_client assert model.model_id == "test-model-123" assert model.params == model_params + assert model.language_autodetect is False + + def test_language_autodetect_passed_to_default_user_message(self, mock_client, model_params, mocker): + """Test that language_autodetect is forwarded when resolving default user message text.""" + mock_get_user_message = mocker.patch( + "ai4rag.rag.foundation_models.base_model.get_user_message_text", + return_value="Question: {question}\nReferences: {reference_documents}", + ) + ConcreteFoundationModel( + client=mock_client, + model_id="meta-llama/llama-3-1-8b-instruct", + params=model_params, + language_autodetect=True, + ) + mock_get_user_message.assert_called_once_with( + model_name="meta-llama/llama-3-1-8b-instruct", + language_autodetect=True, + ) def test_repr(self, foundation_model): """Test __repr__ returns model_id.""" diff --git a/tests/unit/ai4rag/rag/foundation_models/test_ogx.py b/tests/unit/ai4rag/rag/foundation_models/test_ogx.py index bed995c..e6670bf 100644 --- a/tests/unit/ai4rag/rag/foundation_models/test_ogx.py +++ b/tests/unit/ai4rag/rag/foundation_models/test_ogx.py @@ -223,7 +223,7 @@ def 
test_user_message_text_default_when_none( context_template_text=valid_context_template, system_message_text=valid_system_message, ) - mock_get_user_message.assert_called_once_with(model_name="llama-3-70b") + mock_get_user_message.assert_called_once_with(model_name="llama-3-70b", language_autodetect=False) assert "Default user message" in model.user_message_text def test_context_template_text_custom(self, mock_ogx_client, valid_user_message_template, valid_system_message): diff --git a/tests/unit/ai4rag/rag/foundation_models/test_utils.py b/tests/unit/ai4rag/rag/foundation_models/test_utils.py index c6d3ac2..ec1fdf7 100644 --- a/tests/unit/ai4rag/rag/foundation_models/test_utils.py +++ b/tests/unit/ai4rag/rag/foundation_models/test_utils.py @@ -11,6 +11,7 @@ ) from ai4rag.search_space.src.model_props import ( CONTEXT_TEXT_PLACEHOLDER, + DOCUMENT_NUMBER_PLACEHOLDER, QUESTION_PLACEHOLDER, REFERENCE_DOCUMENTS_PLACEHOLDER, ) @@ -54,6 +55,12 @@ def test_context_template_missing_placeholder(self): assert "Incorrect number of placeholders" in str(exc_info.value) assert "expected 1 but got 0" in str(exc_info.value) + def test_context_template_with_doc_number(self): + """Test that context template may include optional doc_number placeholder.""" + template = f"Document {{{DOCUMENT_NUMBER_PLACEHOLDER}}}:\n{{{CONTEXT_TEXT_PLACEHOLDER}}}\n" + result = _validate_prompt_templates_placeholders(template, "context_template_text") + assert result == template + def test_user_message_missing_one_placeholder(self): """Test that user message with only one placeholder raises ValueError.""" template = f"Question: {{{QUESTION_PLACEHOLDER}}}" @@ -225,6 +232,13 @@ def test_context_template_missing_placeholder(self): assert "Incorrect number of placeholders" in str(exc_info.value) assert "expected 1 but got 0" in str(exc_info.value) + def test_context_template_with_doc_number(self): + """Test that context template may include optional doc_number placeholder.""" + validator = 
RAGPromptTemplateString("context_template_text") + template = f"Document {{{DOCUMENT_NUMBER_PLACEHOLDER}}}:\n{{{CONTEXT_TEXT_PLACEHOLDER}}}\n" + result = validator.validate(None, template) + assert result == template + def test_user_message_missing_one_placeholder(self): """Test that user message with only one placeholder raises ConstraintsValidationError.""" validator = RAGPromptTemplateString("user_message_text") diff --git a/tests/unit/ai4rag/rag/template/test_simple_rag_template.py b/tests/unit/ai4rag/rag/template/test_simple_rag_template.py index 1a7eedb..56c6ec2 100644 --- a/tests/unit/ai4rag/rag/template/test_simple_rag_template.py +++ b/tests/unit/ai4rag/rag/template/test_simple_rag_template.py @@ -445,6 +445,24 @@ def test_generate_builds_context_from_retrieved_documents( assert "Document: Relevant document 2" in user_message assert "What is AI?" in user_message + def test_generate_numbers_documents_when_template_includes_doc_number( + self, + mock_foundation_model, + mock_retriever, + ): + """Test that doc_number is passed when the context template includes it.""" + mock_foundation_model.context_template_text = "Document {doc_number}:\n{document}\n" + rag = SimpleRAG( + foundation_model=mock_foundation_model, + retriever=mock_retriever, + ) + + rag.generate("What is AI?") + + user_message = mock_foundation_model.chat.call_args.kwargs["messages"][1]["content"] + assert "Document 1:\nRelevant document 1" in user_message + assert "Document 2:\nRelevant document 2" in user_message + def test_generate_calls_foundation_model_chat( self, mock_foundation_model, diff --git a/tests/unit/ai4rag/search_space/src/test_model_props.py b/tests/unit/ai4rag/search_space/src/test_model_props.py new file mode 100644 index 0000000..e5d5c15 --- /dev/null +++ b/tests/unit/ai4rag/search_space/src/test_model_props.py @@ -0,0 +1,91 @@ +# ----------------------------------------------------------------------------- +# Copyright IBM Corp. 
2026 +# SPDX-License-Identifier: Apache-2.0 +# ----------------------------------------------------------------------------- +"""Tests for default RAG prompt templates.""" + +import pytest + +from ai4rag.search_space.src.model_props import ( + DOCUMENT_NUMBER_PLACEHOLDER, + get_context_template_text, + get_system_message_text, + get_user_message_text, +) + + +@pytest.mark.parametrize( + "model_name", + [ + "meta-llama/llama-3-1-8b-instruct", + "ibm/granite-3-8b-instruct", + "mistralai/mistral-large", + "openai/gpt-oss-120b", + "unknown-model", + ], +) +def test_user_message_includes_grounding_and_citations(model_name: str): + user_message = get_user_message_text(model_name) + assert "Answer ONLY" in user_message + assert "MUST cite sources" in user_message + assert "{reference_documents}" in user_message + assert "{question}" in user_message + + +@pytest.mark.parametrize( + "model_name", + [ + "meta-llama/llama-3-1-8b-instruct", + "ibm/granite-3-8b-instruct", + "mistralai/mistral-large", + "openai/gpt-oss-120b", + "vllm-inference-gpu-llama/redhataillama-31-8b-instruct", + ], +) +def test_system_message_includes_rag_prefix(model_name: str): + system_message = get_system_message_text(model_name) + assert "retrieval-augmented assistant" in system_message + assert "ONLY the provided documents" in system_message + + +@pytest.mark.parametrize( + "model_name", + [ + "meta-llama/llama-3-1-8b-instruct", + "ibm/granite-3-8b-instruct", + "mistralai/mistral-large", + "openai/gpt-oss-120b", + "unknown-model", + ], +) +def test_context_template_numbers_documents(model_name: str): + context_template = get_context_template_text(model_name) + assert f"{{{DOCUMENT_NUMBER_PLACEHOLDER}}}" in context_template + assert "{document}" in context_template + + +def test_language_autodetect_defaults_to_english_only(): + user_message = get_user_message_text("meta-llama/llama-3-1-8b-instruct") + assert "You MUST write your entire answer in English only" in user_message + assert "Do NOT use any 
other language" in user_message + + +def test_language_autodetect_enabled_uses_strong_question_language_instruction(): + user_message = get_user_message_text("meta-llama/llama-3-1-8b-instruct", language_autodetect=True) + assert "You MUST write your entire answer in the same language as the question" in user_message + assert "Do NOT respond in any other language" in user_message + + +@pytest.mark.parametrize( + "model_name", + [ + "meta-llama/llama-3-1-8b-instruct", + "ibm/granite-3-8b-instruct", + "mistralai/mistral-large", + "openai/gpt-oss-120b", + "unknown-model", + ], +) +def test_user_message_includes_consistent_answer_length(model_name: str): + user_message = get_user_message_text(model_name) + assert "Answer (max 150 words, with citations):" in user_message From b97a10245a9b1f249e519201bbb756396ae92432 Mon Sep 17 00:00:00 2001 From: Jakub Walaszczyk Date: Thu, 25 Jun 2026 13:32:00 +0200 Subject: [PATCH 07/16] ci: Remove mike dependency, documentation versioning and update documentation publish script. Signed-off-by: Jakub Walaszczyk Assisted-by: Claude Code Signed-off-by: Lukasz Cmielowski --- .github/workflows/docs-deploy.yml | 27 +++++----------------- mkdocs.yml | 4 ---- pyproject.toml | 1 - uv.lock | 38 ------------------------------- 4 files changed, 6 insertions(+), 64 deletions(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 9c291d9..81eeaa9 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -3,10 +3,9 @@ name: Deploy Documentation on: workflow_dispatch: inputs: - branch: - description: 'Branch to build docs from' + tag: + description: 'Git tag to build docs from (e.g. 
v0.8.1)' required: true - default: 'main' type: string permissions: @@ -20,9 +19,8 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 with: - ref: ${{ github.event.inputs.branch || github.ref }} - fetch-depth: 0 # Fetch all history for git-revision-date-localized plugin - token: ${{ secrets.GITHUB_TOKEN }} # Use GITHUB_TOKEN for authentication + ref: ${{ inputs.tag }} + fetch-depth: 0 - name: Set up uv uses: astral-sh/setup-uv@v6 @@ -32,21 +30,8 @@ jobs: - name: Install dependencies run: uv sync --extra docs - - name: Extract version - id: version - run: | - VERSION=$(uv run python -c "import ai4rag; print(ai4rag.__version__)") - # Strip 'v' prefix if present - VERSION=${VERSION#v} - echo "docs_version=$VERSION" >> $GITHUB_OUTPUT - echo "Extracted version: $VERSION" - - - name: Configure Git for mike + - name: Deploy documentation run: | git config user.name "github-actions[bot]" git config user.email "github-actions[bot]@users.noreply.github.com" - - - name: Deploy documentation (manual trigger) - if: github.event_name == 'workflow_dispatch' - run: | - uv run mike deploy --push ${{ steps.version.outputs.docs_version }} + uv run mkdocs gh-deploy --force diff --git a/mkdocs.yml b/mkdocs.yml index 859a666..3961c7a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -220,10 +220,6 @@ nav: # Extra configuration extra: - version: - provider: mike - default: latest - social: - icon: fontawesome/brands/github link: https://github.com/IBM/ai4rag diff --git a/pyproject.toml b/pyproject.toml index 336a645..136fd7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,6 @@ docs = [ "mkdocstrings[python]>=0.25.0", "mkdocs-git-revision-date-localized-plugin>=1.2.0", "mkdocs-minify-plugin>=0.8.0", - "mike>=2.0.0", ] [tool.setuptools.dynamic] diff --git a/uv.lock b/uv.lock index 1000093..8517aef 100644 --- a/uv.lock +++ b/uv.lock @@ -57,7 +57,6 @@ dev = [ { name = "dotenv" }, { name = "ipykernel" }, { name = "isort" }, - { name = "mike" }, { name = "mkdocs" }, { name 
= "mkdocs-git-revision-date-localized-plugin" }, { name = "mkdocs-material" }, @@ -71,7 +70,6 @@ dev = [ { name = "pytest-mock" }, ] docs = [ - { name = "mike" }, { name = "mkdocs" }, { name = "mkdocs-git-revision-date-localized-plugin" }, { name = "mkdocs-material" }, @@ -102,7 +100,6 @@ requires-dist = [ { name = "langchain", specifier = "~=1.1.3" }, { name = "langchain-chroma", specifier = "~=1.1.0" }, { name = "langchain-text-splitters", specifier = "~=1.1.0" }, - { name = "mike", marker = "extra == 'docs'", specifier = ">=2.0.0" }, { name = "mkdocs", marker = "extra == 'docs'", specifier = "~=1.6.0" }, { name = "mkdocs-git-revision-date-localized-plugin", marker = "extra == 'docs'", specifier = ">=1.2.0" }, { name = "mkdocs-material", marker = "extra == 'docs'", specifier = "~=9.5" }, @@ -1913,23 +1910,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2c/19/04f9b178c2d8a15b076c8b5140708fa6ffc5601fb6f1e975537072df5b2a/mergedeep-1.3.4-py3-none-any.whl", hash = "sha256:70775750742b25c0d8f36c55aed03d24c3384d17c951b3175d898bd778ef0307", size = 6354, upload-time = "2021-02-05T18:55:29.583Z" }, ] -[[package]] -name = "mike" -version = "2.2.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "jinja2" }, - { name = "mkdocs" }, - { name = "pyparsing" }, - { name = "pyyaml" }, - { name = "pyyaml-env-tag" }, - { name = "verspec" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b4/47/fa87e9d56bef16cdfe34b059a437e8c6f7ec6f1b9c378871c3cf95ebea9c/mike-2.2.0.tar.gz", hash = "sha256:1e3858e32c0f125aac14432fc7848434358f9ae0962c5c5cde387ad47f6ad25e", size = 38450, upload-time = "2026-04-14T04:59:03.944Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/76/8e/56ccb09c7232a55403a7637caa21922f3b65901a37f5e8bdb405d0de0946/mike-2.2.0-py3-none-any.whl", hash = "sha256:e1f4981c1152eec7c2490a3401142292cc47d686194188416db2648fdfe1d040", size = 34026, upload-time = "2026-04-14T04:59:02.602Z" }, -] - [[package]] 
name = "mkdocs" version = "1.6.1" @@ -3372,15 +3352,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/59/af/e6618858bd8f9be6c58ea7238b7ec224d1a31df506dd912c41672fe4f369/pyobjc_framework_vision-12.2.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:e288cb41349d6e84cfac0822a0c1fb476bf5fa094913b19e8c2899e90a1a9e8f", size = 17105, upload-time = "2026-06-19T16:19:24.817Z" }, ] -[[package]] -name = "pyparsing" -version = "3.3.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f3/91/9c6ee907786a473bf81c5f53cf703ba0957b23ab84c264080fb5a450416f/pyparsing-3.3.2.tar.gz", hash = "sha256:c777f4d763f140633dcb6d8a3eda953bf7a214dc4eff598413c070bcdc117cbc", size = 6851574, upload-time = "2026-01-21T03:57:59.36Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/10/bd/c038d7cc38edc1aa5bf91ab8068b63d4308c66c4c8bb3cbba7dfbc049f9c/pyparsing-3.3.2-py3-none-any.whl", hash = "sha256:850ba148bd908d7e2411587e247a1e4f0327839c40e2e5e6d05a007ecc69911d", size = 122781, upload-time = "2026-01-21T03:57:55.912Z" }, -] - [[package]] name = "pypdfium2" version = "5.8.0" @@ -4553,15 +4524,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/ba/d69adbe699b768f6b29a5eec7b47dd610bd17a69de51b251126a801369ea/uvloop-0.22.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1f38ec5e3f18c8a10ded09742f7fb8de0108796eb673f30ce7762ce1b8550cad", size = 4239051, upload-time = "2025-10-16T22:16:43.224Z" }, ] -[[package]] -name = "verspec" -version = "0.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e7/44/8126f9f0c44319b2efc65feaad589cadef4d77ece200ae3c9133d58464d0/verspec-0.1.0.tar.gz", hash = "sha256:c4504ca697b2056cdb4bfa7121461f5a0e81809255b41c03dda4ba823637c01e", size = 27123, upload-time = "2020-11-30T02:24:09.646Z" } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/a4/ce/3b6fee91c85626eaf769d617f1be9d2e15c1cca027bbdeb2e0d751469355/verspec-0.1.0-py3-none-any.whl", hash = "sha256:741877d5633cc9464c45a469ae2a31e801e6dbbaa85b9675d481cda100f11c31", size = 19640, upload-time = "2020-11-30T02:24:08.387Z" }, -] - [[package]] name = "watchdog" version = "6.0.0" From 71b524791e0a2e8af20c0c9eb02f09156b476536 Mon Sep 17 00:00:00 2001 From: Jakub Walaszczyk Date: Thu, 25 Jun 2026 15:28:14 +0200 Subject: [PATCH 08/16] Revert "feature: Update prompt templates (#75)" This reverts commit 5217636a3a7db26001d43332ce23bf8965da7fc9. Signed-off-by: Lukasz Cmielowski Assisted-by: Cursor --- .../assets_generator/pattern_builder.py | 6 + .../rag_templates_optimization.py | 42 +++ .../optimization/search_space_preparation.py | 163 +++++++++++ ai4rag/rag/foundation_models/base_model.py | 6 +- ai4rag/rag/foundation_models/ogx.py | 2 - ai4rag/rag/foundation_models/utils.py | 30 +- ai4rag/rag/template/simple_rag_template.py | 5 +- ai4rag/search_space/src/model_props.py | 145 ++++------ dev_utils/mocks.py | 2 - .../assets_generator/test_pattern_builder.py | 13 + .../optimization/test_language_detection.py | 256 ++++++++++++++++++ .../optimization/test_rag_optimization.py | 105 +++++++ .../optimization/test_search_space_prep.py | 26 ++ .../rag/foundation_models/test_base_model.py | 18 -- .../ai4rag/rag/foundation_models/test_ogx.py | 2 +- .../rag/foundation_models/test_utils.py | 14 - .../rag/template/test_simple_rag_template.py | 18 -- .../search_space/src/test_model_props.py | 91 ------- 18 files changed, 682 insertions(+), 262 deletions(-) create mode 100644 tests/unit/ai4rag/components/optimization/test_language_detection.py delete mode 100644 tests/unit/ai4rag/search_space/src/test_model_props.py diff --git a/ai4rag/components/assets_generator/pattern_builder.py b/ai4rag/components/assets_generator/pattern_builder.py index 9d1c848..6babc84 100644 --- a/ai4rag/components/assets_generator/pattern_builder.py +++ 
b/ai4rag/components/assets_generator/pattern_builder.py @@ -4,6 +4,7 @@ # ----------------------------------------------------------------------------- def build_pattern_json( pattern: dict, + detected_language: dict | None = None, ) -> dict: """Update pattern information with detected language and responses template. @@ -12,12 +13,17 @@ def build_pattern_json( pattern : dict A single evaluation result object carrying ``indexing_params``, ``rag_params``, ``pattern_name``, ``collection``, etc. + detected_language : dict | None, default=None + Language detection result (``{"code": "...", "name": "..."}``). Returns ------- dict Pattern definition suitable for JSON serialisation. """ + if detected_language: + pattern["settings"]["generation"]["detected_language"] = detected_language + pattern["settings"]["responses_template"] = { "model": pattern["settings"]["generation"]["model_id"], "stream": False, diff --git a/ai4rag/components/optimization/rag_templates_optimization.py b/ai4rag/components/optimization/rag_templates_optimization.py index cccf42c..a967147 100644 --- a/ai4rag/components/optimization/rag_templates_optimization.py +++ b/ai4rag/components/optimization/rag_templates_optimization.py @@ -131,6 +131,8 @@ def run_rag_optimization( # pylint: disable=too-many-locals,too-many-arguments, with open(search_space_report_path, "r", encoding="utf-8") as f: search_space_raw = yml.safe_load(f) + detected_language: dict[str, str] | None = search_space_raw.pop("detected_language", None) + search_space = AI4RAGSearchSpace( params=[Parameter(param, "C", values=values) for param, values in search_space_raw.items()] ) @@ -145,6 +147,9 @@ def run_rag_optimization( # pylint: disable=too-many-locals,too-many-arguments, benchmark_data = pd.read_json(Path(test_data_path)) + if detected_language: + _inject_language_instructions(search_space, detected_language) + rag_exp = AI4RAGExperiment( client=ogx_client, event_handler=event_handler, @@ -174,6 +179,7 @@ def 
run_rag_optimization( # pylint: disable=too-many-locals,too-many-arguments, pattern_data = build_pattern_json( pattern=pattern.get("payload"), + detected_language=detected_language, ) # Generate notebooks @@ -270,6 +276,42 @@ def _evaluation_result_fallback(eval_data_list: list, evaluation_result: Any) -> return out +def _inject_language_instructions( + search_space: AI4RAGSearchSpace, + detected_language: dict[str, str], +) -> None: + """Inject explicit language response instructions into foundation models. + + When a non-English language is detected, each foundation model's system + and user messages are augmented with an instruction to respond in that + language. + + Parameters + ---------- + search_space + The search space whose foundation models will be modified in-place. + detected_language + Detection result with ``"code"`` and ``"name"`` keys. + """ + lang_code = detected_language.get("code", "") + lang_name = detected_language.get("name", "") + if not lang_name or lang_code == "en": + return + + explicit_instruction = f"You MUST respond in {lang_name}." + for fm in search_space["foundation_model"].values: + existing_sys = fm.system_message_text + existing_usr = str(fm.user_message_text) + fm.system_message_text = f"{explicit_instruction} {existing_sys}" + fm.user_message_text = f"{existing_usr}{explicit_instruction}" + + _logger.info( + "Set explicit language instruction on %d foundation model(s): %s", + len(search_space["foundation_model"].values), + explicit_instruction, + ) + + def _validate_optimization_settings(optimization_settings: dict | None) -> dict: """Validate and normalize optimization settings. 
diff --git a/ai4rag/components/optimization/search_space_preparation.py b/ai4rag/components/optimization/search_space_preparation.py index f30295e..2561023 100644 --- a/ai4rag/components/optimization/search_space_preparation.py +++ b/ai4rag/components/optimization/search_space_preparation.py @@ -30,6 +30,29 @@ _DEFAULT_SAMPLE_SIZE = 5 _DEFAULT_SEED = 17 +LANGUAGE_MAP: dict[str, str] = { + "ja": "Japanese", + "ko": "Korean", + "zh-cn": "Chinese", + "zh-tw": "Chinese", + "en": "English", + "de": "German", + "fr": "French", + "es": "Spanish", + "pt": "Portuguese", + "it": "Italian", + "ru": "Russian", + "ar": "Arabic", + "hi": "Hindi", + "th": "Thai", + "vi": "Vietnamese", + "pl": "Polish", + "nl": "Dutch", + "sv": "Swedish", + "cs": "Czech", + "tr": "Turkish", +} + def _represent_model_instance(dumper: yml.Dumper, model: BaseFoundationModel | BaseEmbeddingModel) -> yml.Node: """Instruct :mod:`yaml` on how to serialize model instances under a ``!Model`` tag. @@ -71,10 +94,14 @@ class SearchSpaceReport: model lists and non-model parameter ranges. selected_models : dict[str, list] Foundation and embedding model lists that survived pre-selection. + detected_language : dict[str, str] | None + Detected language code and name, or ``None`` when English or when + detection was not performed. """ search_space: dict[str, Any] selected_models: dict[str, list] + detected_language: dict[str, str] | None def save_yaml(self, path: str | Path) -> None: """Serialize the report to a YAML file. @@ -87,6 +114,8 @@ def save_yaml(self, path: str | Path) -> None: Destination file path. 
""" report = dict(self.search_space) + if self.detected_language: + report["detected_language"] = self.detected_language path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) @@ -170,6 +199,10 @@ def prepare_search_space_report( # pylint: disable=too-many-locals,too-many-arg # Load benchmark data and documents benchmark_df = pd.read_json(Path(test_data_path)) + detected_language = _detect_benchmark_language( + benchmark_df, llm_client=ogx_client, generation_models=generation_models + ) + benchmark_data = BenchmarkData(benchmark_df) documents = load_docling_documents(extracted_text_path) @@ -211,6 +244,7 @@ def prepare_search_space_report( # pylint: disable=too-many-locals,too-many-arg return SearchSpaceReport( search_space=verbose_repr, selected_models=selected_models, + detected_language=detected_language, ) @@ -223,3 +257,132 @@ def _validate_model_list(models: list[str] | None, name: str) -> None: for i, m in enumerate(models): if not m: raise TypeError(f"{name}[{i}] must be a non-empty string.") + + +def _detect_language_via_llm( # pylint: disable=too-many-locals + questions: list[str], + llm_client: OgxClient, + allowed_generation_models: list[str] | None = None, +) -> dict[str, str] | None: + """Detect the dominant language from sample questions using an LLM. + + Sends a small sample of questions to a generation model registered in OGX + and asks it to return the ISO 639-1 code. Models listed in + *allowed_generation_models* are preferred when available. + + Parameters + ---------- + questions + Raw question texts to classify. Only the first five are sent to the + model. + llm_client + An authenticated :class:`OgxClient` instance. + allowed_generation_models + Optional whitelist of model identifiers to prefer. + + Returns + ------- + dict[str, str] | None + A dictionary with ``code`` and ``name`` keys when a non-English + language is detected, or ``None`` for English / on failure. 
+ """ + sample_text = "\n".join(f"- {q}" for q in questions[:5]) + valid_codes = ", ".join(sorted(LANGUAGE_MAP.keys())) + + try: + models_response = llm_client.models.list() + models_list = models_response.data if hasattr(models_response, "data") else list(models_response) + registered_ids = {(m.identifier if hasattr(m, "identifier") else str(m.id)) for m in models_list} + + model_id: str | None = None + if allowed_generation_models: + for gm in allowed_generation_models: + if gm in registered_ids: + model_id = gm + break + if not model_id: + for m in models_list: + if hasattr(m, "model_type") and getattr(m, "model_type", "") == "llm": + model_id = m.identifier if hasattr(m, "identifier") else str(m.id) + break + if not model_id: + _logger.warning("No models available for LLM language detection.") + return None + + response = llm_client.chat.completions.create( + model=model_id, + messages=[ + { + "role": "system", + "content": ( + "You are a language detection assistant. " + "Given text samples, respond with ONLY the ISO 639-1 language code " + f"(one of: {valid_codes}). " + "Nothing else — just the code." 
+ ), + }, + { + "role": "user", + "content": f"What language are these questions written in?\n{sample_text}", + }, + ], + max_completion_tokens=10, + temperature=0.0, + ) + if not response.choices: + _logger.warning("LLM returned empty choices for language detection.") + return None + raw = response.choices[0].message.content.strip().lower().replace('"', "").replace("'", "") + detected_code = raw.split()[0] if raw else None + + if not detected_code: + return None + + name = LANGUAGE_MAP.get(detected_code) + if not name: + _logger.warning("LLM returned unsupported language code: %s", detected_code) + return None + + _logger.info("Language detected via LLM: %s (%s)", detected_code, name) + return {"code": detected_code, "name": name} + + except Exception as exc: + _logger.warning("LLM language detection failed: %s", exc) + return None + + +def _detect_benchmark_language( + benchmark_df: pd.DataFrame, + llm_client: OgxClient, + generation_models: list[str] | None = None, + sample_size: int = 10, +) -> dict[str, str] | None: + """Detect the dominant language from benchmark question data. + + Extracts up to *sample_size* questions from the ``question`` column and + delegates to :func:`_detect_language_via_llm` for classification. + + Parameters + ---------- + benchmark_df + DataFrame with a ``question`` column. + llm_client + An authenticated :class:`OgxClient` instance. + generation_models + Optional whitelist of model identifiers passed through to the LLM + detection step. + sample_size + Maximum number of questions to sample. + + Returns + ------- + dict[str, str] | None + A dictionary with ``code`` and ``name`` keys for the detected language + (English included), or ``None`` if there are no questions or detection fails.
+ """ + questions = benchmark_df["question"].dropna().astype(str).tolist() + if not questions: + return None + + sample = questions[:sample_size] + return _detect_language_via_llm(sample, llm_client, allowed_generation_models=generation_models) diff --git a/ai4rag/rag/foundation_models/base_model.py b/ai4rag/rag/foundation_models/base_model.py index 3e381c8..50c83ce 100644 --- a/ai4rag/rag/foundation_models/base_model.py +++ b/ai4rag/rag/foundation_models/base_model.py @@ -37,17 +37,13 @@ def __init__( system_message_text: str | None = None, user_message_text: str | None = None, context_template_text: str | None = None, - language_autodetect: bool = False, ): self.client = client self.model_id = model_id self.params = params - self.language_autodetect = language_autodetect self.system_message_text = system_message_text or get_system_message_text(model_name=model_id) self.user_message_text = ( - user_message_text - if user_message_text is not None - else get_user_message_text(model_name=model_id, language_autodetect=language_autodetect) + user_message_text if user_message_text is not None else get_user_message_text(model_name=model_id) ) self.context_template_text = ( context_template_text diff --git a/ai4rag/rag/foundation_models/ogx.py b/ai4rag/rag/foundation_models/ogx.py index a04db14..51945d8 100644 --- a/ai4rag/rag/foundation_models/ogx.py +++ b/ai4rag/rag/foundation_models/ogx.py @@ -32,7 +32,6 @@ def __init__( system_message_text: str | None = None, user_message_text: str | None = None, context_template_text: str | None = None, - language_autodetect: bool = False, ): super().__init__( @@ -42,7 +41,6 @@ def __init__( system_message_text=system_message_text, user_message_text=user_message_text, context_template_text=context_template_text, - language_autodetect=language_autodetect, ) @property diff --git a/ai4rag/rag/foundation_models/utils.py b/ai4rag/rag/foundation_models/utils.py index a19f272..1aa0b9b 100644 --- a/ai4rag/rag/foundation_models/utils.py +++ 
b/ai4rag/rag/foundation_models/utils.py @@ -7,7 +7,6 @@ from ai4rag.search_space.src.model_props import ( CONTEXT_TEXT_PLACEHOLDER, - DOCUMENT_NUMBER_PLACEHOLDER, QUESTION_PLACEHOLDER, REFERENCE_DOCUMENTS_PLACEHOLDER, ) @@ -30,12 +29,11 @@ def __init__( super().__init__() self.template_name = template_name - if template_name == "context_template_text": - self._required_placeholders: tuple[str, ...] = (CONTEXT_TEXT_PLACEHOLDER,) - self._optional_placeholders: tuple[str, ...] = (DOCUMENT_NUMBER_PLACEHOLDER,) - else: - self._required_placeholders = (QUESTION_PLACEHOLDER, REFERENCE_DOCUMENTS_PLACEHOLDER) - self._optional_placeholders = () + self._required_placeholders: tuple[str, ...] = ( + (CONTEXT_TEXT_PLACEHOLDER,) + if template_name == "context_template_text" + else (QUESTION_PLACEHOLDER, REFERENCE_DOCUMENTS_PLACEHOLDER) + ) def validate(self, _: object, value: T) -> T: """ @@ -65,20 +63,18 @@ def validate(self, _: object, value: T) -> T: raise TypeError(f"Expected {value!r} to be a str or None.") placeholders_count = 0 - allowed = self._required_placeholders + self._optional_placeholders for _, field_name, _, _ in Formatter().parse(value): if field_name is None: # when there is text NOT followed by a placeholder template continue - if field_name not in allowed: + if field_name not in self._required_placeholders: raise ConstraintsValidationError( f"Custom {field_name.split('_')[0]} template text got unexpected placeholder `{field_name}`, " - f"valid placeholders are `{allowed}`." + f"valid placeholders are `{self._required_placeholders}`." 
) - if field_name in self._required_placeholders: - placeholders_count += 1 + placeholders_count += 1 if placeholders_count != len(self._required_placeholders): raise ConstraintsValidationError( @@ -118,28 +114,24 @@ def _validate_prompt_templates_placeholders( """ if template_name == "context_template_text": required_placeholders = (CONTEXT_TEXT_PLACEHOLDER,) - optional_placeholders = (DOCUMENT_NUMBER_PLACEHOLDER,) elif template_name == "user_message_text": required_placeholders = (QUESTION_PLACEHOLDER, REFERENCE_DOCUMENTS_PLACEHOLDER) - optional_placeholders = () else: raise ValueError(f"Cannot validate presence of expected template placeholders on field: {template_name}") placeholders_count = 0 - allowed = required_placeholders + optional_placeholders for _, field_name, _, _ in Formatter().parse(template_str): if field_name is None: # when there is text NOT followed by a placeholder template continue - if field_name not in allowed: + if field_name not in required_placeholders: raise ValueError( f"Custom {field_name.split('_')[0]} template text got unexpected placeholder `{field_name}`, " - f"valid placeholders are `{allowed}`." + f"valid placeholders are `{required_placeholders}`." 
) - if field_name in required_placeholders: - placeholders_count += 1 + placeholders_count += 1 if placeholders_count != len(required_placeholders): raise ValueError( diff --git a/ai4rag/rag/template/simple_rag_template.py b/ai4rag/rag/template/simple_rag_template.py index df73018..cd5b0ac 100644 --- a/ai4rag/rag/template/simple_rag_template.py +++ b/ai4rag/rag/template/simple_rag_template.py @@ -95,9 +95,8 @@ def generate(self, question: str, **kwargs) -> dict[str, Any]: """ reference_documents = self.retriever.retrieve(question, **kwargs) - context = "\n\n".join( - self.foundation_model.context_template_text.format(document=chunk.text, doc_number=doc_number) - for doc_number, chunk in enumerate(reference_documents, start=1) + context = "\n".join( + [self.foundation_model.context_template_text.format(document=chunk.text) for chunk in reference_documents] ) user_message = self.foundation_model.user_message_text.format( diff --git a/ai4rag/search_space/src/model_props.py b/ai4rag/search_space/src/model_props.py index 1a07e10..e853ed0 100644 --- a/ai4rag/search_space/src/model_props.py +++ b/ai4rag/search_space/src/model_props.py @@ -9,7 +9,6 @@ "QUESTION_PLACEHOLDER", "REFERENCE_DOCUMENTS_PLACEHOLDER", "CONTEXT_TEXT_PLACEHOLDER", - "DOCUMENT_NUMBER_PLACEHOLDER", "MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER", ] @@ -17,65 +16,43 @@ QUESTION_PLACEHOLDER = "question" REFERENCE_DOCUMENTS_PLACEHOLDER = "reference_documents" CONTEXT_TEXT_PLACEHOLDER = "document" -DOCUMENT_NUMBER_PLACEHOLDER = "doc_number" MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER = "multilingual_support" +# A mapping from model name into their corresponding prompt templates. +# The parameters for the prompt templates are QUESTION_PLACEHOLDER and REFERENCE_DOCUMENTS_PLACEHOLDER + _MULTILINGUAL_SUPPORT_ENABLED_PROMPT = ( - "You MUST write your entire answer in the same language as the question. " - "Do NOT respond in any other language, even if the documents use a different language. 
" - "Every word of your answer must match the question's language." + "Respond exclusively in the language of the question, " + "regardless of any other language used in the provided context. " + "Ensure that your entire response is in the same language as the question." ) _MULTILINGUAL_SUPPORT_DISABLED_PROMPT = ( - "You MUST write your entire answer in English only. " - "Do NOT use any other language, even if the question or documents are in another language. " - "Every word of your answer must be in English." -) - - -_RAG_GROUNDING_INSTRUCTION = ( - "Answer ONLY using information from the documents below. " - "Do not use outside knowledge. " - "If the documents do not contain the answer, say you do not have enough information." -) - - -_RAG_CITATION_INSTRUCTION = ( - "You MUST cite sources using [1], [2], etc. matching the document numbers for every factual claim." + "Respond exclusively in English, " + "regardless of the language of the question or any other language used in the provided context. " + "Ensure that your entire response is in English only." ) -_RAG_ANSWER_LENGTH_GUIDANCE = "max 150 words" - - -_RAG_ANSWER_PROMPT_LINE = f"Answer ({_RAG_ANSWER_LENGTH_GUIDANCE}, with citations):\n" - - -_RAG_SYSTEM_PREFIX = "You are a retrieval-augmented assistant. Answer using ONLY the provided documents. " - - -_DEFAULT_NUMBERED_CONTEXT_TEMPLATE = f"Document {{{DOCUMENT_NUMBER_PLACEHOLDER}}}:\n{{{CONTEXT_TEXT_PLACEHOLDER}}}\n" - - _DEFAULT_SYSTEM_MESSAGE_TEXT = ( - f"{_RAG_SYSTEM_PREFIX}" "If the question is unanswerable from the documents, say you cannot answer." + "Please answer the question I provide in the Question section below, " + "based solely on the information I provide in the Context section. " + "If the question is unanswerable, please say you cannot answer." 
) _DEFAULT_USER_MESSAGE_TEXT = ( - f"{_RAG_GROUNDING_INSTRUCTION}\n" - f"{_RAG_CITATION_INSTRUCTION}\n\n" - f"Documents:\n{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}\n\n" - f"Question: {{{QUESTION_PLACEHOLDER}}}\n\n" - f"{_RAG_ANSWER_PROMPT_LINE}" - f"{{{MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER}}}\n" + f"\n\nContext:\n{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}:\n\n" + f"Question: {{{QUESTION_PLACEHOLDER}}}. \n" + "Again, please answer the question based on the context provided only. If the context is not related to " + "the question, just say you cannot answer. " + f"{{{MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER}}}" ) _DEFAULT_GRANITE_SYSTEM_MESSAGE_TEXT = ( - f"{_RAG_SYSTEM_PREFIX}" "You are Granite Chat, an AI language model developed by IBM. " "You are a cautious assistant. You carefully follow instructions. " "You are helpful and harmless and you follow ethical guidelines and promote positive behaviour." @@ -83,39 +60,40 @@ _DEFAULT_GRANITE_USER_MESSAGE_TEXT = ( - f"{_RAG_GROUNDING_INSTRUCTION}\n" - f"{_RAG_CITATION_INSTRUCTION}\n\n" - "You are a specialized Retrieval Augmented Generation (RAG) assistant. " - "Prioritize correctness and ensure your response is grounded in the documents.\n\n" - f"Documents:\n{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}\n\n" - f"Question: {{{QUESTION_PLACEHOLDER}}}\n\n" - f"{_RAG_ANSWER_PROMPT_LINE}" - f"{{{MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER}}}\n" + "You are an AI language model designed to function as a specialized Retrieval Augmented Generation (RAG) " + "assistant. When generating responses, prioritize correctness, i.e., ensure that your response is grounded in " + "context and user query. Always make sure that your response is relevant to the question. 
" + "\n" + "Answer Length: detailed" + "\n" + f"{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}" + "\n" + f"{{{MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER}}}" + "\n" + f"{{{QUESTION_PLACEHOLDER}}} " + "\n" + "\n" ) _DEFAULT_LLAMA_SYSTEM_MESSAGE_TEXT = ( - f"{_RAG_SYSTEM_PREFIX}" "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. " "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. " "Please ensure that your responses are socially unbiased and positive in nature.\n" "If a question does not make any sense, or is not factually coherent, explain why instead of answering " - "something not correct. If you don't know the answer to a question, please don't share false information.\n" + "something not correct. If you don’t know the answer to a question, please don’t share false information.\n" ) _DEFAULT_LLAMA_USER_MESSAGE_TEXT = ( - f"{_RAG_GROUNDING_INSTRUCTION}\n" - f"{_RAG_CITATION_INSTRUCTION}\n\n" - f"Documents:\n{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}\n\n" - f"Question: {{{QUESTION_PLACEHOLDER}}}\n\n" - f"{_RAG_ANSWER_PROMPT_LINE}" + f"{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}\n" + f"[conversation]: {{{QUESTION_PLACEHOLDER}}}. Answer with no more than 150 words. If you cannot base your " + "answer on the given document, please state that you do not have an answer. " f"{{{MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER}}}\n" ) _DEFAULT_MISTRAL_SYSTEM_MESSAGE_TEXT = ( - f"{_RAG_SYSTEM_PREFIX}" "You are a helpful, respectful and honest assistant. " "Always answer as helpfully as possible, while being safe. " "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. 
" @@ -126,34 +104,32 @@ _DEFAULT_MISTRAL_USER_MESSAGE_TEXT = ( - f"{_RAG_GROUNDING_INSTRUCTION}\n" - f"{_RAG_CITATION_INSTRUCTION}\n\n" - f"Documents:\n{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}\n\n" - f"Question: {{{QUESTION_PLACEHOLDER}}}\n\n" - f"{_RAG_ANSWER_PROMPT_LINE}" - f"{{{MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER}}}\n" + "Generate the next agent response by answering the question. You are provided several documents with titles. " + "If the answer comes from different documents please mention all possibilities and use the titles of documents " + "to separate between topics or domains. If you cannot base your answer on the given documents, " + f"please state that you do not have an answer. " + f"{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}\n\n" + f"{{{MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER}}}\n\n" + f"{{{QUESTION_PLACEHOLDER}}}" ) _DEFAULT_OPENAI_SYSTEM_MESSAGE_TEXT = ( - f"{_RAG_SYSTEM_PREFIX}" + "You are a AI language model designed to function as a specialized Retrieval Augmented Generation (RAG) assistant. " "When generating responses, prioritize correctness, i.e., ensure that your response is correct given the context " "and user query, and that it is grounded in the context. " "Furthermore, make sure that the response is supported by the given document or context. " "When the question cannot be answered using the context or document, output the following response: " "'I am sorry, I do not have the information you are looking for in my knowledge base.'. " "Always make sure that your response is relevant to the question. 
If an explanation is needed, " - "first provide the explanation or reasoning, and then give the final answer.\n\n" + "first provide the explanation or reasoning, and then give the final answer.\nAnswer Length: concise.\n\n" ) _DEFAULT_OPENAI_USER_MESSAGE_TEXT = ( - f"{_RAG_GROUNDING_INSTRUCTION}\n" - f"{_RAG_CITATION_INSTRUCTION}\n\n" - f"Documents:\n{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}\n\n" - f"Question: {{{QUESTION_PLACEHOLDER}}}\n\n" - f"{_RAG_ANSWER_PROMPT_LINE}" - f"{{{MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER}}}\n" + f"[Document]\n{{{REFERENCE_DOCUMENTS_PLACEHOLDER}}}\n[End]\n" + f"{{{QUESTION_PLACEHOLDER}}}. \n" + f"{{{MULTILINGUAL_SUPPORT_INSTRUCTION_PLACEHOLDER}}}" ) @@ -185,11 +161,12 @@ } -_DEFAULT_GRANITE_CONTEXT_TEMPLATE = _DEFAULT_NUMBERED_CONTEXT_TEMPLATE -_DEFAULT_LLAMA_CONTEXT_TEMPLATE = _DEFAULT_NUMBERED_CONTEXT_TEMPLATE -_DEFAULT_MISTRAL_CONTEXT_TEMPLATE = _DEFAULT_NUMBERED_CONTEXT_TEMPLATE -_DEFAULT_OPENAI_CONTEXT_TEMPLATE = _DEFAULT_NUMBERED_CONTEXT_TEMPLATE -_DEFAULT_CONTEXT_TEMPLATE = _DEFAULT_NUMBERED_CONTEXT_TEMPLATE +# A mapping from model names into their corresponding context template texts. These templates describe how each +# retrieved context is to be wrapped, before being integrated into a full RAG prompt text. 
+# The parameter for the context template text is CONTEXT_TEXT_PLACEHOLDER +_DEFAULT_GRANITE_CONTEXT_TEMPLATE = f"[Document]\n{{{CONTEXT_TEXT_PLACEHOLDER}}}\n[End]" +_DEFAULT_LLAMA_CONTEXT_TEMPLATE = f"[document]: {{{CONTEXT_TEXT_PLACEHOLDER}}}\n" +_DEFAULT_CONTEXT_TEMPLATE = f"{{{CONTEXT_TEXT_PLACEHOLDER}}}" _model_name_to_context_template_text = { "meta-llama/llama-3-1-70b-instruct": _DEFAULT_LLAMA_CONTEXT_TEMPLATE, @@ -198,10 +175,7 @@ "meta-llama/llama-4-maverick-17b-128e-instruct-fp8": _DEFAULT_LLAMA_CONTEXT_TEMPLATE, "ibm/granite-3-8b-instruct": _DEFAULT_GRANITE_CONTEXT_TEMPLATE, "ibm/granite-3-3-8b-instruct": _DEFAULT_GRANITE_CONTEXT_TEMPLATE, - "mistralai/mistral-small-3-1-24b-instruct-2503": _DEFAULT_MISTRAL_CONTEXT_TEMPLATE, - "mistralai/mistral-medium-2505": _DEFAULT_MISTRAL_CONTEXT_TEMPLATE, - "mistralai/mistral-large": _DEFAULT_MISTRAL_CONTEXT_TEMPLATE, - "openai/gpt-oss-120b": _DEFAULT_OPENAI_CONTEXT_TEMPLATE, + "openai/gpt-oss-120b": _DEFAULT_CONTEXT_TEMPLATE, } @@ -209,8 +183,8 @@ def get_context_template_text(model_name: str) -> str: """ Get a model-specific context template text. - The context template text is a template with placeholders ``document`` and, - optionally, ``doc_number``. + The context template text is a template with one placeholder: "context_text". + This field should be populated before use within a RAG prompt. 
Parameters ---------- @@ -229,8 +203,6 @@ def get_context_template_text(model_name: str) -> str: context_template = _DEFAULT_GRANITE_CONTEXT_TEMPLATE elif "llama" in model_name: context_template = _DEFAULT_LLAMA_CONTEXT_TEMPLATE - elif "mistral" in model_name: - context_template = _DEFAULT_MISTRAL_CONTEXT_TEMPLATE else: context_template = _DEFAULT_CONTEXT_TEMPLATE @@ -261,15 +233,13 @@ def get_system_message_text(model_name: str) -> str: system_message_text = _DEFAULT_LLAMA_SYSTEM_MESSAGE_TEXT elif "mistral" in model_name: system_message_text = _DEFAULT_MISTRAL_SYSTEM_MESSAGE_TEXT - elif "openai" in model_name or "gpt" in model_name: - system_message_text = _DEFAULT_OPENAI_SYSTEM_MESSAGE_TEXT else: system_message_text = _DEFAULT_SYSTEM_MESSAGE_TEXT return system_message_text -def get_user_message_text(model_name: str, language_autodetect: bool = False) -> str: +def get_user_message_text(model_name: str, language_autodetect: bool = True) -> str: """ Get a model-specific prompt text. @@ -283,7 +253,6 @@ def get_user_message_text(model_name: str, language_autodetect: bool = False) -> language_autodetect : bool If True, language of the question will be automatically detected. - Defaults to False (English-only responses) for stronger faithfulness on English benchmarks. 
Returns ------- @@ -299,8 +268,6 @@ def get_user_message_text(model_name: str, language_autodetect: bool = False) -> user_message_text = _DEFAULT_LLAMA_USER_MESSAGE_TEXT elif "mistral" in model_name: user_message_text = _DEFAULT_MISTRAL_USER_MESSAGE_TEXT - elif "openai" in model_name or "gpt" in model_name: - user_message_text = _DEFAULT_OPENAI_USER_MESSAGE_TEXT else: user_message_text = _DEFAULT_USER_MESSAGE_TEXT diff --git a/dev_utils/mocks.py b/dev_utils/mocks.py index 5a6986e..c46543c 100644 --- a/dev_utils/mocks.py +++ b/dev_utils/mocks.py @@ -20,7 +20,6 @@ def __init__( system_message_text: str | None = None, user_message_text: str | None = None, context_template_text: str | None = None, - language_autodetect: bool = False, ): super().__init__( client=client, @@ -29,7 +28,6 @@ def __init__( system_message_text=system_message_text, user_message_text=user_message_text, context_template_text=context_template_text, - language_autodetect=language_autodetect, ) def chat(self, messages: list[MessageTyped]) -> list[MessageTyped]: diff --git a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py index 3752333..c4a0ee4 100644 --- a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py +++ b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py @@ -97,6 +97,19 @@ def test_returns_same_dict(self): result = build_pattern_json(pattern) assert result is pattern + def test_no_detected_language_by_default(self): + """When detected_language is None, no detected_language key appears in generation.""" + pattern = _make_pattern() + build_pattern_json(pattern) + assert "detected_language" not in pattern["settings"]["generation"] + + def test_detected_language_injected(self): + """Non-English language detection must inject detected_language into generation.""" + pattern = _make_pattern() + lang = {"code": "de", "name": "German"} + build_pattern_json(pattern, detected_language=lang) + assert 
pattern["settings"]["generation"]["detected_language"] == lang + def test_hybrid_rrf_ranking_options(self): """Hybrid search with RRF ranker must set ranker and impact_factor in ranking_options.""" pattern = _make_pattern() diff --git a/tests/unit/ai4rag/components/optimization/test_language_detection.py b/tests/unit/ai4rag/components/optimization/test_language_detection.py new file mode 100644 index 0000000..5447736 --- /dev/null +++ b/tests/unit/ai4rag/components/optimization/test_language_detection.py @@ -0,0 +1,256 @@ +# ----------------------------------------------------------------------------- +# Copyright IBM Corp. 2026 +# SPDX-License-Identifier: Apache-2.0 +# ----------------------------------------------------------------------------- +from __future__ import annotations + +from unittest.mock import MagicMock + +import pandas as pd +import pytest + +from ai4rag.components.optimization.search_space_preparation import ( + LANGUAGE_MAP, +) +from ai4rag.components.optimization.search_space_preparation import ( + _detect_benchmark_language as detect_benchmark_language, +) +from ai4rag.components.optimization.search_space_preparation import _detect_language_via_llm as detect_language_via_llm + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def mock_ogx_client() -> MagicMock: + """Return a MagicMock that behaves like an OgxClient with one LLM model.""" + mock_model = MagicMock() + mock_model.identifier = "test-model" + mock_model.model_type = "llm" + + mock_models_response = MagicMock() + mock_models_response.data = [mock_model] + + mock_choice = MagicMock() + mock_choice.message.content = "ja" + mock_response = MagicMock() + mock_response.choices = [mock_choice] + + mock_client = MagicMock() + mock_client.models.list.return_value = mock_models_response + mock_client.chat.completions.create.return_value = 
mock_response + return mock_client + + +@pytest.fixture() +def sample_questions() -> list[str]: + """Return a short list of sample questions.""" + return [ + "東京の天気はどうですか?", + "日本の首都はどこですか?", + "富士山の高さは?", + ] + + +# --------------------------------------------------------------------------- +# LANGUAGE_MAP +# --------------------------------------------------------------------------- + + +class TestLanguageMap: + """Verify the static LANGUAGE_MAP contents.""" + + def test_known_codes_present(self): + """Well-known ISO 639-1 codes must resolve to their language names.""" + assert LANGUAGE_MAP["ja"] == "Japanese" + assert LANGUAGE_MAP["en"] == "English" + assert LANGUAGE_MAP["pl"] == "Polish" + assert LANGUAGE_MAP["de"] == "German" + assert LANGUAGE_MAP["fr"] == "French" + assert LANGUAGE_MAP["ko"] == "Korean" + + def test_chinese_variants(self): + """Both simplified and traditional Chinese codes must be present.""" + assert LANGUAGE_MAP["zh-cn"] == "Chinese" + assert LANGUAGE_MAP["zh-tw"] == "Chinese" + + def test_all_values_are_nonempty_strings(self): + """Every value in LANGUAGE_MAP must be a non-empty human-readable name.""" + for code, name in LANGUAGE_MAP.items(): + assert isinstance(code, str) and code, f"Invalid code: {code!r}" + assert isinstance(name, str) and name, f"Invalid name for {code!r}: {name!r}" + + +# --------------------------------------------------------------------------- +# detect_language_via_llm +# --------------------------------------------------------------------------- + + +class TestDetectLanguageViaLlm: + """Tests for the LLM-based language detection function.""" + + def test_detects_japanese(self, mock_ogx_client, sample_questions): + """When the LLM returns 'ja', the result must contain the correct code and name.""" + result = detect_language_via_llm(sample_questions, mock_ogx_client) + + assert result is not None + assert result == {"code": "ja", "name": "Japanese"} + mock_ogx_client.chat.completions.create.assert_called_once() + 
+ def test_english_returns_none(self, mock_ogx_client, sample_questions): + """English is the default language, so detection must return None.""" + mock_ogx_client.chat.completions.create.return_value.choices[0].message.content = "en" + + result = detect_language_via_llm(sample_questions, mock_ogx_client) + + # English maps to a valid entry in LANGUAGE_MAP, so it returns the dict. + # The contract says "None for English" only at the detect_benchmark_language + # level. At this level the function returns the mapping when the code is + # valid, regardless of which language it is. + # Re-reading the source: the function returns {"code": ..., "name": ...} + # for ANY valid code, including English. + assert result == {"code": "en", "name": "English"} + + def test_api_failure_returns_none(self, mock_ogx_client, sample_questions): + """An exception from the OGX client must be swallowed, returning None.""" + mock_ogx_client.chat.completions.create.side_effect = RuntimeError("API unavailable") + + result = detect_language_via_llm(sample_questions, mock_ogx_client) + + assert result is None + + def test_unsupported_language_code_returns_none(self, mock_ogx_client, sample_questions): + """An ISO code not present in LANGUAGE_MAP must return None.""" + mock_ogx_client.chat.completions.create.return_value.choices[0].message.content = "xx" + + result = detect_language_via_llm(sample_questions, mock_ogx_client) + + assert result is None + + def test_no_models_available_returns_none(self, sample_questions): + """When no models are registered, the function must return None.""" + mock_model_response = MagicMock() + mock_model_response.data = [] + + mock_client = MagicMock() + mock_client.models.list.return_value = mock_model_response + + result = detect_language_via_llm(sample_questions, mock_client) + + assert result is None + mock_client.chat.completions.create.assert_not_called() + + def test_prefers_allowed_generation_model(self, sample_questions): + """When 
allowed_generation_models is set, the preferred model must be used.""" + preferred_model = MagicMock() + preferred_model.identifier = "preferred-llm" + preferred_model.model_type = "llm" + + other_model = MagicMock() + other_model.identifier = "other-llm" + other_model.model_type = "llm" + + mock_models_response = MagicMock() + mock_models_response.data = [other_model, preferred_model] + + mock_choice = MagicMock() + mock_choice.message.content = "ja" + mock_response = MagicMock() + mock_response.choices = [mock_choice] + + mock_client = MagicMock() + mock_client.models.list.return_value = mock_models_response + mock_client.chat.completions.create.return_value = mock_response + + detect_language_via_llm(sample_questions, mock_client, allowed_generation_models=["preferred-llm"]) + + call_kwargs = mock_client.chat.completions.create.call_args + assert call_kwargs[1]["model"] == "preferred-llm" or call_kwargs.kwargs["model"] == "preferred-llm" + + def test_samples_at_most_five_questions(self, mock_ogx_client): + """Only the first five questions should appear in the prompt.""" + many_questions = [f"Question {i}" for i in range(20)] + + detect_language_via_llm(many_questions, mock_ogx_client) + + call_kwargs = mock_ogx_client.chat.completions.create.call_args + user_content = call_kwargs[1]["messages"][1]["content"] + # The prompt enumerates "- Q" lines; at most 5 should appear. 
+ assert user_content.count("- Question") == 5 + + def test_empty_llm_response_returns_none(self, mock_ogx_client, sample_questions): + """A blank response from the LLM must return None.""" + mock_ogx_client.chat.completions.create.return_value.choices[0].message.content = " " + + result = detect_language_via_llm(sample_questions, mock_ogx_client) + + assert result is None + + def test_models_list_failure_returns_none(self, sample_questions): + """An exception during models.list() must be swallowed.""" + mock_client = MagicMock() + mock_client.models.list.side_effect = ConnectionError("timeout") + + result = detect_language_via_llm(sample_questions, mock_client) + + assert result is None + + +# --------------------------------------------------------------------------- +# detect_benchmark_language +# --------------------------------------------------------------------------- + + +class TestDetectBenchmarkLanguage: + """Tests for the DataFrame-level language detection wrapper.""" + + def test_detects_language_from_dataframe(self, mock_ogx_client): + """A DataFrame with a 'question' column must yield detection results.""" + df = pd.DataFrame({"question": ["東京の天気は?", "富士山の高さは?", "日本の首都は?"]}) + + result = detect_benchmark_language(df, mock_ogx_client) + + assert result is not None + assert result["code"] == "ja" + mock_ogx_client.chat.completions.create.assert_called_once() + + def test_empty_dataframe_returns_none(self, mock_ogx_client): + """An empty DataFrame must short-circuit to None without calling the LLM.""" + df = pd.DataFrame({"question": pd.Series([], dtype=str)}) + + result = detect_benchmark_language(df, mock_ogx_client) + + assert result is None + mock_ogx_client.chat.completions.create.assert_not_called() + + def test_all_nan_questions_returns_none(self, mock_ogx_client): + """When every question value is NaN, the function must return None.""" + df = pd.DataFrame({"question": [None, None, None]}) + + result = detect_benchmark_language(df, 
mock_ogx_client) + + assert result is None + mock_ogx_client.chat.completions.create.assert_not_called() + + def test_respects_sample_size(self, mock_ogx_client): + """The sample_size parameter must cap the number of questions forwarded.""" + df = pd.DataFrame({"question": [f"Q{i}" for i in range(50)]}) + + detect_benchmark_language(df, mock_ogx_client, sample_size=3) + + call_kwargs = mock_ogx_client.chat.completions.create.call_args + user_content = call_kwargs[1]["messages"][1]["content"] + # detect_language_via_llm further caps to 5, but sample_size=3 means + # only 3 questions are passed in. + assert user_content.count("- Q") == 3 + + def test_passes_generation_models_through(self, mock_ogx_client): + """The generation_models parameter must reach detect_language_via_llm.""" + df = pd.DataFrame({"question": ["Hello?"]}) + + detect_benchmark_language(df, mock_ogx_client, generation_models=["custom-model"]) + + # The function should still call the LLM; the model selection logic + # inside detect_language_via_llm handles the allowed list. 
+ mock_ogx_client.chat.completions.create.assert_called_once() diff --git a/tests/unit/ai4rag/components/optimization/test_rag_optimization.py b/tests/unit/ai4rag/components/optimization/test_rag_optimization.py index 5ee6efe..6e082ab 100644 --- a/tests/unit/ai4rag/components/optimization/test_rag_optimization.py +++ b/tests/unit/ai4rag/components/optimization/test_rag_optimization.py @@ -12,9 +12,12 @@ DEFAULT_MAX_RAG_PATTERNS, MIN_MAX_RAG_PATTERNS_RANGE, SUPPORTED_OPTIMIZATION_METRICS, + _inject_language_instructions, _validate_optimization_settings, run_rag_optimization, ) +from ai4rag.search_space.src.parameter import Parameter +from ai4rag.search_space.src.search_space import AI4RAGSearchSpace # --------------------------------------------------------------------------- # Fixtures @@ -27,6 +30,15 @@ def mock_ogx_client() -> MagicMock: return MagicMock() +@pytest.fixture() +def foundation_model_stub() -> MagicMock: + """Return a MagicMock that acts as a foundation model with message attributes.""" + fm = MagicMock() + fm.system_message_text = "You are a helpful assistant." 
+ fm.user_message_text = "Answer: {question}" + return fm + + # --------------------------------------------------------------------------- # _validate_optimization_settings # --------------------------------------------------------------------------- @@ -118,6 +130,99 @@ def test_extra_keys_preserved(self): assert result["metric"] == "faithfulness" +# --------------------------------------------------------------------------- +# _inject_language_instructions +# --------------------------------------------------------------------------- + + +class TestInjectLanguageInstructions: + """Tests for injecting language-specific instructions into foundation models.""" + + @staticmethod + def _make_search_space(foundation_models: list) -> AI4RAGSearchSpace: + """Build a minimal AI4RAGSearchSpace with required parameters.""" + dummy_embedding = MagicMock() + dummy_embedding.context_length = 8192 + return AI4RAGSearchSpace( + params=[ + Parameter("foundation_model", "C", values=foundation_models), + Parameter("embedding_model", "C", values=[dummy_embedding]), + ] + ) + + def test_injects_language_into_system_message(self, foundation_model_stub): + """The system message must be prefixed with the language instruction.""" + search_space = self._make_search_space([foundation_model_stub]) + + _inject_language_instructions(search_space, {"code": "ja", "name": "Japanese"}) + + fm = search_space["foundation_model"].values[0] + assert fm.system_message_text.startswith("You MUST respond in Japanese.") + + def test_injects_language_into_user_message(self, foundation_model_stub): + """The user message must be appended with the language instruction.""" + search_space = self._make_search_space([foundation_model_stub]) + + _inject_language_instructions(search_space, {"code": "ja", "name": "Japanese"}) + + fm = search_space["foundation_model"].values[0] + assert fm.user_message_text.endswith("You MUST respond in Japanese.") + + def test_english_code_skips_injection(self, 
foundation_model_stub): + """English language detection must not modify the messages.""" + original_sys = foundation_model_stub.system_message_text + original_usr = foundation_model_stub.user_message_text + + search_space = self._make_search_space([foundation_model_stub]) + + _inject_language_instructions(search_space, {"code": "en", "name": "English"}) + + fm = search_space["foundation_model"].values[0] + assert fm.system_message_text == original_sys + assert fm.user_message_text == original_usr + + def test_empty_name_skips_injection(self, foundation_model_stub): + """An empty language name must not inject any instruction.""" + original_sys = foundation_model_stub.system_message_text + + search_space = self._make_search_space([foundation_model_stub]) + + _inject_language_instructions(search_space, {"code": "ja", "name": ""}) + + fm = search_space["foundation_model"].values[0] + assert fm.system_message_text == original_sys + + def test_multiple_foundation_models(self): + """All foundation models in the search space must be updated.""" + fm1 = MagicMock() + fm1.system_message_text = "sys1" + fm1.user_message_text = "usr1" + + fm2 = MagicMock() + fm2.system_message_text = "sys2" + fm2.user_message_text = "usr2" + + search_space = self._make_search_space([fm1, fm2]) + + _inject_language_instructions(search_space, {"code": "ko", "name": "Korean"}) + + for fm in search_space["foundation_model"].values: + assert "You MUST respond in Korean." in fm.system_message_text + assert "You MUST respond in Korean." 
in fm.user_message_text + + def test_none_system_message_gets_instruction_prepended(self): + """A model with None system_message_text must still get the language instruction.""" + fm = MagicMock() + fm.system_message_text = None + fm.user_message_text = None + + search_space = self._make_search_space([fm]) + + _inject_language_instructions(search_space, {"code": "de", "name": "German"}) + + assert fm.system_message_text.startswith("You MUST respond in German.") + + # --------------------------------------------------------------------------- # run_rag_optimization -- input validation only # --------------------------------------------------------------------------- diff --git a/tests/unit/ai4rag/components/optimization/test_search_space_prep.py b/tests/unit/ai4rag/components/optimization/test_search_space_prep.py index 76d8e4e..752f05a 100644 --- a/tests/unit/ai4rag/components/optimization/test_search_space_prep.py +++ b/tests/unit/ai4rag/components/optimization/test_search_space_prep.py @@ -34,6 +34,7 @@ def simple_report() -> SearchSpaceReport: "foundation_model": ["model-a"], "embedding_model": ["emb-a"], }, + detected_language={"code": "ja", "name": "Japanese"}, ) @@ -93,6 +94,31 @@ def test_save_yaml_creates_file(self, simple_report, tmp_path: Path): data = yml.safe_load(out_file.read_text()) assert isinstance(data, dict) + def test_save_yaml_includes_detected_language(self, simple_report, tmp_path: Path): + """When detected_language is set, it must appear in the YAML output.""" + import yaml as yml + + out_file = tmp_path / "report.yaml" + simple_report.save_yaml(out_file) + + data = yml.safe_load(out_file.read_text()) + assert data["detected_language"] == {"code": "ja", "name": "Japanese"} + + def test_save_yaml_omits_language_when_none(self, tmp_path: Path): + """When detected_language is None, the key must not appear in output.""" + import yaml as yml + + report = SearchSpaceReport( + search_space={"chunk_size": [256]}, + 
selected_models={"foundation_model": []}, + detected_language=None, + ) + out_file = tmp_path / "report.yaml" + report.save_yaml(out_file) + + data = yml.safe_load(out_file.read_text()) + assert "detected_language" not in data + def test_save_yaml_creates_parent_directories(self, simple_report, tmp_path: Path): """save_yaml must create intermediate directories if they do not exist.""" out_file = tmp_path / "nested" / "dir" / "report.yaml" diff --git a/tests/unit/ai4rag/rag/foundation_models/test_base_model.py b/tests/unit/ai4rag/rag/foundation_models/test_base_model.py index 0e58efd..4a40739 100644 --- a/tests/unit/ai4rag/rag/foundation_models/test_base_model.py +++ b/tests/unit/ai4rag/rag/foundation_models/test_base_model.py @@ -54,24 +54,6 @@ def test_init(self, mock_client, model_params): assert model.client == mock_client assert model.model_id == "test-model-123" assert model.params == model_params - assert model.language_autodetect is False - - def test_language_autodetect_passed_to_default_user_message(self, mock_client, model_params, mocker): - """Test that language_autodetect is forwarded when resolving default user message text.""" - mock_get_user_message = mocker.patch( - "ai4rag.rag.foundation_models.base_model.get_user_message_text", - return_value="Question: {question}\nReferences: {reference_documents}", - ) - ConcreteFoundationModel( - client=mock_client, - model_id="meta-llama/llama-3-1-8b-instruct", - params=model_params, - language_autodetect=True, - ) - mock_get_user_message.assert_called_once_with( - model_name="meta-llama/llama-3-1-8b-instruct", - language_autodetect=True, - ) def test_repr(self, foundation_model): """Test __repr__ returns model_id.""" diff --git a/tests/unit/ai4rag/rag/foundation_models/test_ogx.py b/tests/unit/ai4rag/rag/foundation_models/test_ogx.py index e6670bf..bed995c 100644 --- a/tests/unit/ai4rag/rag/foundation_models/test_ogx.py +++ b/tests/unit/ai4rag/rag/foundation_models/test_ogx.py @@ -223,7 +223,7 @@ def 
test_user_message_text_default_when_none( context_template_text=valid_context_template, system_message_text=valid_system_message, ) - mock_get_user_message.assert_called_once_with(model_name="llama-3-70b", language_autodetect=False) + mock_get_user_message.assert_called_once_with(model_name="llama-3-70b") assert "Default user message" in model.user_message_text def test_context_template_text_custom(self, mock_ogx_client, valid_user_message_template, valid_system_message): diff --git a/tests/unit/ai4rag/rag/foundation_models/test_utils.py b/tests/unit/ai4rag/rag/foundation_models/test_utils.py index ec1fdf7..c6d3ac2 100644 --- a/tests/unit/ai4rag/rag/foundation_models/test_utils.py +++ b/tests/unit/ai4rag/rag/foundation_models/test_utils.py @@ -11,7 +11,6 @@ ) from ai4rag.search_space.src.model_props import ( CONTEXT_TEXT_PLACEHOLDER, - DOCUMENT_NUMBER_PLACEHOLDER, QUESTION_PLACEHOLDER, REFERENCE_DOCUMENTS_PLACEHOLDER, ) @@ -55,12 +54,6 @@ def test_context_template_missing_placeholder(self): assert "Incorrect number of placeholders" in str(exc_info.value) assert "expected 1 but got 0" in str(exc_info.value) - def test_context_template_with_doc_number(self): - """Test that context template may include optional doc_number placeholder.""" - template = f"Document {{{DOCUMENT_NUMBER_PLACEHOLDER}}}:\n{{{CONTEXT_TEXT_PLACEHOLDER}}}\n" - result = _validate_prompt_templates_placeholders(template, "context_template_text") - assert result == template - def test_user_message_missing_one_placeholder(self): """Test that user message with only one placeholder raises ValueError.""" template = f"Question: {{{QUESTION_PLACEHOLDER}}}" @@ -232,13 +225,6 @@ def test_context_template_missing_placeholder(self): assert "Incorrect number of placeholders" in str(exc_info.value) assert "expected 1 but got 0" in str(exc_info.value) - def test_context_template_with_doc_number(self): - """Test that context template may include optional doc_number placeholder.""" - validator = 
RAGPromptTemplateString("context_template_text") - template = f"Document {{{DOCUMENT_NUMBER_PLACEHOLDER}}}:\n{{{CONTEXT_TEXT_PLACEHOLDER}}}\n" - result = validator.validate(None, template) - assert result == template - def test_user_message_missing_one_placeholder(self): """Test that user message with only one placeholder raises ConstraintsValidationError.""" validator = RAGPromptTemplateString("user_message_text") diff --git a/tests/unit/ai4rag/rag/template/test_simple_rag_template.py b/tests/unit/ai4rag/rag/template/test_simple_rag_template.py index 56c6ec2..1a7eedb 100644 --- a/tests/unit/ai4rag/rag/template/test_simple_rag_template.py +++ b/tests/unit/ai4rag/rag/template/test_simple_rag_template.py @@ -445,24 +445,6 @@ def test_generate_builds_context_from_retrieved_documents( assert "Document: Relevant document 2" in user_message assert "What is AI?" in user_message - def test_generate_numbers_documents_when_template_includes_doc_number( - self, - mock_foundation_model, - mock_retriever, - ): - """Test that doc_number is passed when the context template includes it.""" - mock_foundation_model.context_template_text = "Document {doc_number}:\n{document}\n" - rag = SimpleRAG( - foundation_model=mock_foundation_model, - retriever=mock_retriever, - ) - - rag.generate("What is AI?") - - user_message = mock_foundation_model.chat.call_args.kwargs["messages"][1]["content"] - assert "Document 1:\nRelevant document 1" in user_message - assert "Document 2:\nRelevant document 2" in user_message - def test_generate_calls_foundation_model_chat( self, mock_foundation_model, diff --git a/tests/unit/ai4rag/search_space/src/test_model_props.py b/tests/unit/ai4rag/search_space/src/test_model_props.py deleted file mode 100644 index e5d5c15..0000000 --- a/tests/unit/ai4rag/search_space/src/test_model_props.py +++ /dev/null @@ -1,91 +0,0 @@ -# ----------------------------------------------------------------------------- -# Copyright IBM Corp. 
2026 -# SPDX-License-Identifier: Apache-2.0 -# ----------------------------------------------------------------------------- -"""Tests for default RAG prompt templates.""" - -import pytest - -from ai4rag.search_space.src.model_props import ( - DOCUMENT_NUMBER_PLACEHOLDER, - get_context_template_text, - get_system_message_text, - get_user_message_text, -) - - -@pytest.mark.parametrize( - "model_name", - [ - "meta-llama/llama-3-1-8b-instruct", - "ibm/granite-3-8b-instruct", - "mistralai/mistral-large", - "openai/gpt-oss-120b", - "unknown-model", - ], -) -def test_user_message_includes_grounding_and_citations(model_name: str): - user_message = get_user_message_text(model_name) - assert "Answer ONLY" in user_message - assert "MUST cite sources" in user_message - assert "{reference_documents}" in user_message - assert "{question}" in user_message - - -@pytest.mark.parametrize( - "model_name", - [ - "meta-llama/llama-3-1-8b-instruct", - "ibm/granite-3-8b-instruct", - "mistralai/mistral-large", - "openai/gpt-oss-120b", - "vllm-inference-gpu-llama/redhataillama-31-8b-instruct", - ], -) -def test_system_message_includes_rag_prefix(model_name: str): - system_message = get_system_message_text(model_name) - assert "retrieval-augmented assistant" in system_message - assert "ONLY the provided documents" in system_message - - -@pytest.mark.parametrize( - "model_name", - [ - "meta-llama/llama-3-1-8b-instruct", - "ibm/granite-3-8b-instruct", - "mistralai/mistral-large", - "openai/gpt-oss-120b", - "unknown-model", - ], -) -def test_context_template_numbers_documents(model_name: str): - context_template = get_context_template_text(model_name) - assert f"{{{DOCUMENT_NUMBER_PLACEHOLDER}}}" in context_template - assert "{document}" in context_template - - -def test_language_autodetect_defaults_to_english_only(): - user_message = get_user_message_text("meta-llama/llama-3-1-8b-instruct") - assert "You MUST write your entire answer in English only" in user_message - assert "Do NOT use any 
other language" in user_message - - -def test_language_autodetect_enabled_uses_strong_question_language_instruction(): - user_message = get_user_message_text("meta-llama/llama-3-1-8b-instruct", language_autodetect=True) - assert "You MUST write your entire answer in the same language as the question" in user_message - assert "Do NOT respond in any other language" in user_message - - -@pytest.mark.parametrize( - "model_name", - [ - "meta-llama/llama-3-1-8b-instruct", - "ibm/granite-3-8b-instruct", - "mistralai/mistral-large", - "openai/gpt-oss-120b", - "unknown-model", - ], -) -def test_user_message_includes_consistent_answer_length(model_name: str): - user_message = get_user_message_text(model_name) - assert "Answer (max 150 words, with citations):" in user_message From 94086d3778bdd0503c57dd20d1d0fb0c36da6618 Mon Sep 17 00:00:00 2001 From: Lukasz Cmielowski Date: Mon, 29 Jun 2026 13:10:59 +0200 Subject: [PATCH 09/16] fix(assets): merge HPO user prompt rules into Responses export system (#83) * fix(assets): merge HPO user prompt rules into Responses export system input Map Chat Completions prompts to responses_template without losing HPO-tuned rules or duplicating OGX file_search runtime text. Add build_responses_system_input() to merge non-redundant static supplements from user_message_text (RAG rules, answer length, language policy) into input[role=system], while stripping {reference_documents}, {question}, and citation instructions that OGX injects via annotation_prompt_params. Wire build_pattern_json() to use the merged system text (detected_language is still applied on generation before export). Add unit tests for export parity across all default model families and legacy prompt patterns. 
RHOAIENG-71231 Signed-off-by: Lukasz Cmielowski Assisted-by: Cursor * black formatting Signed-off-by: Lukasz Cmielowski Assisted-by: Cursor * ruff fix Signed-off-by: Lukasz Cmielowski Assisted-by: Cursor * fix(assets): add fallback for empty system input in Responses export Signed-off-by: Lukasz Cmielowski Assisted-by: Cursor * apply the feedback from Mateusz fix(assets): address PR review for Responses export system input merge Reorganize OGX runtime phrase lists into citation, grounding, and file-search groups with derived combined views. Tighten grounding deduplication so persona-only system prompts no longer suppress valid user supplements (e.g. Granite RAG block). Simplify OGX runtime stripping to a single pass, collapse whitespace after phrase removal, and remove unused citation helpers. Add unit tests for hybrid alpha=1.0 fallthrough, partial OGX sentence removal, explicit vs persona grounding detection, export-parity system input, and required generation fields. Signed-off-by: Lukasz Cmielowski Assisted-by: Cursor * refactor(assets): address staff review on Responses export prompt mapping Remove unreachable citation dedup branch, simplify citation-line detection to use _CITATION_PREFIXES/_CITATION_SUBSTRINGS directly, and derive _system_has_grounding_policy from _GROUNDING_PREFIXES for a single source of truth. Rename _USER_RAG_SCAFFOLD_PREFIXES to _USER_RAG_GROUNDING_PREFIXES and document _join_answer_scaffold_blocks scope. Add direct tests for _is_placeholder_only_export and grounding-prefix coverage; remove duplicate empty-input test.
Signed-off-by: Lukasz Cmielowski Assisted-by: Cursor --------- Signed-off-by: Lukasz Cmielowski --- .../assets_generator/pattern_builder.py | 438 +++++++++++++++++- .../assets_generator/test_pattern_builder.py | 320 ++++++++++++- 2 files changed, 747 insertions(+), 11 deletions(-) diff --git a/ai4rag/components/assets_generator/pattern_builder.py b/ai4rag/components/assets_generator/pattern_builder.py index 6babc84..e65f37f 100644 --- a/ai4rag/components/assets_generator/pattern_builder.py +++ b/ai4rag/components/assets_generator/pattern_builder.py @@ -2,11 +2,421 @@ # Copyright IBM Corp. 2026 # SPDX-License-Identifier: Apache-2.0 # ----------------------------------------------------------------------------- +"""Map HPO Chat Completions prompts to exported Responses ``input[system]`` text. + +OGX-owned phrases are defined below and must stay aligned with +``benchmarking/rag/config.yaml`` (``file_search_params``, ``context_prompt_params``, +``annotation_prompt_params``). If OGX injection strings change, update the lists here. +""" + +import re + +_USER_QUERY_PLACEHOLDER = "" +_EMPTY_SYSTEM_FALLBACK = "You are a helpful assistant." +_EXPORT_SLOT_MARKERS = ("{reference_documents}", "{question}", "{multilingual_support}") + +# Suffix lines after ``{reference_documents}``: drop structural wrappers (e.g. ``[End]``). +_DOCUMENT_SLOT_MARKERS = frozenset({"[Document]", "[End]", "Documents:", "Context:"}) + +# ============================================================================ +# OGX Runtime Injection Strings +# ============================================================================ +# These phrases are injected by OGX at file_search runtime via +# benchmarking/rag/config.yaml (file_search_params, context_prompt_params, +# annotation_prompt_params). HPO export must NOT duplicate them in +# responses_template.input[system]. +# +# If OGX changes injection strings in config.yaml, update these lists. 
+# ============================================================================ + +# Citation-related phrases +_CITATION_PREFIXES = ( + "You MUST cite sources", + "Cite sources immediately", +) +_CITATION_SUBSTRINGS = ( + "[1], [2]", + "<|file-id|>", + "cite as <|", + "file citations", + "document numbers for every factual claim", +) +_HPO_CITATION_INSTRUCTION = ( + "You MUST cite sources using [1], [2], etc. matching the document numbers for every factual claim." +) +_HPO_CITATION_FRAGMENTS = ( + _HPO_CITATION_INSTRUCTION, + "You MUST cite sources using [1], [2], etc.", + "You MUST cite sources using [1], [2].", +) + +# Grounding/retrieval-related phrases +_GROUNDING_PREFIXES = ( + "Answer ONLY using information from the documents", + "Answer ONLY using information from documents retrieved", + "Answer using ONLY the provided documents", + "Answer using ONLY information from documents", + "Do not use outside knowledge", + "If the retrieved documents do not contain", + "If the documents do not contain", +) +_GROUNDING_SUBSTRINGS = ( + "documents below", + "retrieved via file search", + "retrieved to help answer the user", + "supporting information only in answering", +) +_SYSTEM_GROUNDING_PHRASES = ( + "Answer using ONLY the provided documents.", + "Answer using ONLY information from documents retrieved via file search.", +) + +# File search tool markers +_FILE_SEARCH_MARKERS = ( + "file_search tool found", + "BEGIN of file_search tool results", + "END of file_search tool results", + "The above results were retrieved to help answer", + "Use them as supporting information only", + "Do not add extra punctuation. 
Use only the file IDs", +) + +# User template duplicate detection (pass 1 filtering) +_USER_GROUNDING_SKIP_PREFIXES = ( + "Answer ONLY using information from the documents below", + "Do not use outside knowledge", + "If the documents do not contain the answer", +) +_USER_RAG_GROUNDING_PREFIXES = ( + "You are a specialized Retrieval Augmented Generation", + "Prioritize correctness and ensure your response is grounded", +) + +# Document and question slot markers +_DOCUMENT_LABELS = ("Documents:", "Context:", "[Document]") +_QUESTION_PREFIXES = ("Question:", "Q:", "[conversation]:") +_LEGACY_DOCUMENT_MARKERS = ("Documents:\n", "Context:\n", "[Document]\n") + +# Combined line prefixes for sentence-level filtering +_OGX_DUPLICATIVE_LINE_PREFIXES = _CITATION_PREFIXES + _GROUNDING_PREFIXES + _FILE_SEARCH_MARKERS + +# Combined substrings for partial-match filtering +_OGX_DUPLICATIVE_SUBSTRINGS = _CITATION_SUBSTRINGS + _GROUNDING_SUBSTRINGS + + +def _collapse_whitespace(text: str) -> str: + """Collapse repeated interior spaces after phrase removal.""" + return re.sub(r" +", " ", text).strip() + + +def _sentence_is_ogx_duplicative(sentence: str) -> bool: + """Return whether a sentence duplicates OGX file_search runtime injection.""" + stripped = sentence.strip().rstrip(".") + if not stripped: + return True + if any(stripped.startswith(prefix.rstrip(".")) for prefix in _OGX_DUPLICATIVE_LINE_PREFIXES): + return True + normalized = stripped.lower() + return any(fragment.lower() in normalized for fragment in _OGX_DUPLICATIVE_SUBSTRINGS) + + +def _is_citation_related_line(line: str) -> bool: + """Return whether an entire line should be dropped as citation guidance.""" + stripped = line.strip() + if not stripped: + return False + lower = stripped.lower() + if any(stripped.startswith(prefix) for prefix in _CITATION_PREFIXES): + return True + if any(fragment.lower() in lower for fragment in _HPO_CITATION_FRAGMENTS): + return True + return any(sub.lower() in lower for sub in 
_CITATION_SUBSTRINGS) + + +def _filter_ogx_duplicative_sentences(line: str) -> str: + """Remove OGX-duplicative sentences while keeping persona or policy sentences.""" + stripped = line.strip() + if not stripped or _is_citation_related_line(stripped): + return "" + + # Split on ". " only — avoids breaking abbreviations such as "i.e.," + parts = [part.strip() for part in stripped.split(". ") if part.strip()] + if len(parts) <= 1: + if _sentence_is_ogx_duplicative(stripped.rstrip(".")): + return "" + return stripped + + kept = [part.rstrip(".") for part in parts if not _sentence_is_ogx_duplicative(part.rstrip("."))] + if not kept: + return "" + + result = ". ".join(kept) + if stripped.endswith("."): + result += "." + return result + + +def _normalize_answer_scaffold(line: str) -> str: + """Drop citation hints from answer scaffolds; OGX owns citation via annotations.""" + normalized = line.replace(", with citations", "").replace("with citations", "") + return _collapse_whitespace(normalized) + + +def _strip_ogx_runtime_instructions(text: str) -> str: + """Remove text that OGX injects via file_search config at inference time.""" + if not text.strip(): + return "" + + for phrase in _SYSTEM_GROUNDING_PHRASES: + text = text.replace(phrase, "").replace(phrase.rstrip("."), "") + text = _collapse_whitespace(text) + + lines: list[str] = [] + for line in text.splitlines(): + stripped = line.strip() + if not stripped: + if lines and lines[-1] != "": + lines.append("") + continue + if _is_citation_related_line(stripped): + continue + + cleaned = _filter_ogx_duplicative_sentences(stripped) + for fragment in _HPO_CITATION_FRAGMENTS: + if fragment in cleaned: + cleaned = cleaned.replace(fragment, "").strip() + break + cleaned = _normalize_answer_scaffold(cleaned) + if cleaned: + lines.append(cleaned) + + result = "\n".join(lines) + while "\n\n\n" in result: + result = result.replace("\n\n\n", "\n\n") + return result.strip() + + +def _join_answer_scaffold_blocks(lines: list[str]) -> 
str: + """Group lines into paragraph blocks, starting a new block when an answer-scaffold line appears. + + Scaffold lines are specifically in the form ``Answer (...)`` — e.g. + ``"Answer (max 150 words):"`` — as produced by HPO prompt templates. + Other leading text such as ``"Answer:"`` or ``"Response:"`` does NOT + trigger a new block. + """ + if not lines: + return "" + + blocks: list[str] = [] + current_block: list[str] = [] + for line in lines: + if line.startswith("Answer (") and current_block: + blocks.append("\n".join(current_block)) + current_block = [line] + else: + current_block.append(line) + if current_block: + blocks.append("\n".join(current_block)) + return "\n\n".join(blocks) + + +def _should_skip_redundant_user_line(stripped: str, system_has_grounding: bool) -> bool: + """Return whether a user-template line duplicates system policy for export.""" + if _is_citation_related_line(stripped): + return True + return system_has_grounding and any( + stripped.startswith(prefix) for prefix in _GROUNDING_PREFIXES + _USER_RAG_GROUNDING_PREFIXES + ) + + +def _should_skip_user_export_line(stripped: str) -> bool: + """Return whether a merged user line is OGX-owned and must not be exported.""" + if any(stripped.startswith(prefix) for prefix in _USER_GROUNDING_SKIP_PREFIXES): + return True + return _is_citation_related_line(stripped) + + +def _strip_document_slot_prefix(prefix: str) -> str: + """Remove structural labels that wrap the reference-documents slot.""" + for label in _DOCUMENT_LABELS: + if prefix == label: + return "" + if prefix.endswith(label): + return prefix[: -len(label)].strip() + return prefix + + +def _extract_static_suffix_line(stripped: str) -> str | None: + """Return static instruction text from one post-documents template line.""" + if not stripped or stripped == ":" or stripped in _DOCUMENT_SLOT_MARKERS: + return None + if "{question}" in stripped: + without_question = stripped.replace("{question}", "").strip() + for question_prefix in 
_QUESTION_PREFIXES: + if without_question.startswith(question_prefix): + without_question = without_question[len(question_prefix) :].strip() + without_question = without_question.lstrip(":.").strip() + return without_question or None + if stripped.startswith(_QUESTION_PREFIXES): + return None + if "{multilingual_support}" in stripped: + return None + return stripped + + +def _extract_static_user_from_reference_slot(text: str) -> str: + """Extract static instructions from a template that contains ``{reference_documents}``.""" + before, after = text.split("{reference_documents}", 1) + parts: list[str] = [] + prefix = _strip_document_slot_prefix(before.strip()) + if prefix: + parts.append(prefix) + + suffix_lines = [ + line_text + for line_text in (_extract_static_suffix_line(line.strip()) for line in after.splitlines()) + if line_text + ] + if suffix_lines: + parts.append("\n".join(suffix_lines)) + return "\n\n".join(parts).strip() + + +def _extract_static_user_without_reference_slot(text: str) -> str: + """Extract static instructions from legacy templates without an explicit doc slot.""" + doc_idx = len(text) + for marker in _LEGACY_DOCUMENT_MARKERS: + idx = text.find(marker) + if idx != -1: + doc_idx = min(doc_idx, idx) + return text[:doc_idx].strip() + + +def _system_has_grounding_policy(system: str) -> bool: + """Return whether the system prompt already states an explicit document-only grounding rule. + + Uses the same prefix list as sentence-level filtering so that adding a new + OGX phrase to ``_GROUNDING_PREFIXES`` automatically covers system detection too. + Does NOT match descriptive personas like "retrieval-augmented assistant" without + an explicit grounding constraint. + """ + normalized = system.lower() + return any(prefix.lower() in normalized for prefix in _GROUNDING_PREFIXES) + + +def _filter_static_user_for_responses(system: str, static_user: str) -> str: + """Drop user-template lines that duplicate system policy for Responses export. 
+ + Pass 1 of 2: compare against ``original_system`` (author intent before OGX + stripping). Removes user lines that repeat grounding or citation policy already + present in the HPO system prompt. + """ + if not static_user.strip(): + return "" + + system_has_grounding = _system_has_grounding_policy(system) + + filtered_lines: list[str] = [] + for line in static_user.splitlines(): + stripped = line.strip() + if not stripped or _should_skip_redundant_user_line(stripped, system_has_grounding): + continue + filtered_lines.append(stripped) + + return _join_answer_scaffold_blocks(filtered_lines) + + +def _adapt_system_for_responses_export(system: str) -> str: + """Drop OGX-runtime retrieval/citation text from the HPO system prompt.""" + return _strip_ogx_runtime_instructions(system) + + +def _adapt_static_user_for_responses_export(static_user: str) -> str: + """Drop merged user supplements that OGX injects at file_search runtime. + + Pass 2 of 2: strip OGX-runtime phrases from user lines that survived pass 1. + """ + if not static_user.strip(): + return "" + + adapted_lines: list[str] = [] + for line in static_user.splitlines(): + stripped = line.strip() + if not stripped or _should_skip_user_export_line(stripped): + continue + cleaned = _strip_ogx_runtime_instructions(stripped) + if cleaned: + adapted_lines.append(cleaned) + + return _join_answer_scaffold_blocks(adapted_lines) + + +def _extract_static_user_instructions(user_message_text: str) -> str: + """Return static instruction text from a HPO user template. + + Strips runtime slots (retrieved documents, question) that Responses API + supplies via ``file_search`` and the user ``input`` message respectively. 
+ """ + if not user_message_text: + return "" + + text = str(user_message_text).strip() + if "{reference_documents}" in text: + return _extract_static_user_from_reference_slot(text) + + prefix = _extract_static_user_without_reference_slot(text) + return prefix + + +def _is_placeholder_only_export(text: str) -> bool: + """Return whether export text contains only unresolved HPO template slots.""" + cleaned = text.strip() + if not cleaned: + return True + for marker in _EXPORT_SLOT_MARKERS: + cleaned = cleaned.replace(marker, "") + return not cleaned.strip() + + +def build_responses_system_input(generation: dict) -> str: + """Build Responses API system input aligned with HPO chat/completion prompts. + + HPO sends ``system_message_text`` plus a formatted ``user_message_text`` + (rules, documents, question). Responses uses ``file_search`` for documents + and a separate user message for the question. Non-redundant supplements + from the user template are merged into export; retrieval framing, chunk + presentation, and citation instructions owned by OGX ``config.yaml`` are + stripped rather than rephrased into the exported system input. + """ + original_system = (generation.get("system_message_text") or "").strip() + exported_system = _adapt_system_for_responses_export(original_system) + user_template = generation.get("user_message_text") or "" + + # Pass 1: dedupe vs original_system; pass 2: strip OGX-owned user supplements. 
+ static_user = _adapt_static_user_for_responses_export( + _filter_static_user_for_responses( + original_system, + _extract_static_user_instructions(user_template), + ), + ) + + if exported_system and static_user: + result = f"{exported_system}\n\n{static_user}" + else: + result = exported_system or static_user + + # Fallback for completely empty patterns (rare edge case) + if not result or not result.strip() or _is_placeholder_only_export(result): + return _EMPTY_SYSTEM_FALLBACK + + return result + + def build_pattern_json( pattern: dict, detected_language: dict | None = None, ) -> dict: - """Update pattern information with detected language and responses template. + """Update pattern information with responses template. Parameters ---------- @@ -14,29 +424,39 @@ def build_pattern_json( A single evaluation result object carrying ``indexing_params``, ``rag_params``, ``pattern_name``, ``collection``, etc. detected_language : dict | None, default=None - Language detection result (``{"code": "...", "name": "..."}``). + Language detection result (``{"code": "...", "name": "..."}``) stored + on ``generation`` before building the Responses export. + + Notes + ----- + ``pattern["settings"]["generation"]`` must include ``model_id``, + ``temperature``, ``max_completion_tokens``, ``system_message_text``, and + ``user_message_text`` (as produced by the experiment payload). Returns ------- dict Pattern definition suitable for JSON serialisation. 
""" + generation = pattern["settings"]["generation"] if detected_language: - pattern["settings"]["generation"]["detected_language"] = detected_language + generation["detected_language"] = detected_language + + system_input = build_responses_system_input(generation) pattern["settings"]["responses_template"] = { - "model": pattern["settings"]["generation"]["model_id"], + "model": generation["model_id"], "stream": False, "store": False, "input": [ { - "content": [{"text": pattern["settings"]["generation"]["system_message_text"], "type": "input_text"}], + "content": [{"text": system_input, "type": "input_text"}], "role": "system", }, - {"content": [{"text": "", "type": "input_text"}], "role": "user"}, + {"content": [{"text": _USER_QUERY_PLACEHOLDER, "type": "input_text"}], "role": "user"}, ], - "max_output_tokens": pattern["settings"]["generation"]["max_completion_tokens"], - "temperature": pattern["settings"]["generation"]["temperature"], + "max_output_tokens": generation["max_completion_tokens"], + "temperature": generation["temperature"], "tool_choice": {"mode": "required", "tools": [{}], "type": "file_search"}, "tools": [ { @@ -60,13 +480,13 @@ def build_pattern_json( "impact_factor": ranker_k, } elif search_mode == "hybrid" and ranker_strategy == "weighted" and ranker_alpha is not None and ranker_alpha != 1: + # ``ranker_alpha == 1.0`` intentionally falls through to ``else`` (semantic-only default). 
pattern["settings"]["responses_template"]["tools"][0]["ranking_options"] = { "ranker": "weighted", "alpha": ranker_alpha, } else: pattern["settings"]["responses_template"]["tools"][0]["ranking_options"] = { - # simulate semantic-only search "ranker": "weighted", "alpha": 1.0, } diff --git a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py index c4a0ee4..b514629 100644 --- a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py +++ b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py @@ -9,6 +9,12 @@ import pytest from ai4rag.components.assets_generator import build_pattern_json +from ai4rag.components.assets_generator.pattern_builder import ( + _is_placeholder_only_export, + _normalize_answer_scaffold, + build_responses_system_input, +) +from ai4rag.search_space.src.model_props import get_system_message_text, get_user_message_text # --------------------------------------------------------------------------- # Helpers @@ -58,6 +64,22 @@ def _make_pattern(**overrides) -> dict: return base +@pytest.mark.parametrize( + ("text", "expected"), + [ + ("", True), + (" ", True), + ("{reference_documents}", True), + ("{reference_documents}\n{question}", True), + ("foo {reference_documents}", False), + ("You are a helpful assistant.", False), + ], +) +def test_is_placeholder_only_export(text: str, expected: bool): + """Placeholder-only export text must trigger the empty-input fallback path.""" + assert _is_placeholder_only_export(text) == expected + + # --------------------------------------------------------------------------- # build_pattern_json -- responses_template generation # --------------------------------------------------------------------------- @@ -72,12 +94,15 @@ def test_adds_responses_template(self): result = build_pattern_json(pattern) rt = result["settings"]["responses_template"] + generation = result["settings"]["generation"] + expected_system = 
build_responses_system_input(generation) + assert rt["model"] == "ibm/granite-3-8b-instruct" assert rt["stream"] is False assert rt["store"] is False assert rt["input"] == [ { - "content": [{"text": "Answer based on context only.", "type": "input_text"}], + "content": [{"text": expected_system, "type": "input_text"}], "role": "system", }, {"content": [{"text": "", "type": "input_text"}], "role": "user"}, @@ -121,6 +146,7 @@ def test_hybrid_rrf_ranking_options(self): ro = pattern["settings"]["responses_template"]["tools"][0]["ranking_options"] assert ro == {"ranker": "rrf", "impact_factor": 60} + assert pattern["settings"]["responses_template"]["tools"][0]["max_num_results"] == 5 def test_hybrid_weighted_ranking_options(self): """Hybrid search with weighted ranker must set ranker and alpha in ranking_options.""" @@ -133,9 +159,10 @@ def test_hybrid_weighted_ranking_options(self): ro = pattern["settings"]["responses_template"]["tools"][0]["ranking_options"] assert ro == {"ranker": "weighted", "alpha": 0.7} + assert pattern["settings"]["responses_template"]["tools"][0]["max_num_results"] == 5 def test_simple_retrieval_default_ranking_options(self): - """Simple retrieval must have default weights in ranking_options.""" + """Vector-only search simulates semantic retrieval via weighted ranker alpha=1.0.""" pattern = _make_pattern() build_pattern_json(pattern) @@ -143,6 +170,295 @@ def test_simple_retrieval_default_ranking_options(self): assert ro == {"ranker": "weighted", "alpha": 1.0} assert pattern["settings"]["responses_template"]["tools"][0]["max_num_results"] == 5 + def test_hybrid_weighted_alpha_one_uses_default_ranking(self): + """Hybrid weighted with alpha=1.0 uses the default semantic-only simulation branch.""" + pattern = _make_pattern() + pattern["settings"]["retrieval"]["search_mode"] = "hybrid" + pattern["settings"]["retrieval"]["ranker_strategy"] = "weighted" + pattern["settings"]["retrieval"]["ranker_alpha"] = 1.0 + + build_pattern_json(pattern) + + ro = 
pattern["settings"]["responses_template"]["tools"][0]["ranking_options"] + assert ro == {"ranker": "weighted", "alpha": 1.0} + + def test_export_system_input_merges_non_redundant_user_rules(self): + """Legacy user supplements merge; redundant grounding and citations are omitted.""" + pattern = _make_pattern() + pattern["settings"]["generation"][ + "system_message_text" + ] = "You are a retrieval-augmented assistant. Answer using ONLY the provided documents." + pattern["settings"]["generation"]["user_message_text"] = ( + "Answer ONLY using information from the documents below. " + "Do not use outside knowledge.\n" + "You MUST cite sources using [1], [2], etc.\n\n" + "Documents:\n{reference_documents}\n\n" + "Question: {question}\n\n" + "Answer (max 150 words, with citations):\n" + "You MUST write your entire answer in English only." + ) + + build_pattern_json(pattern) + + system_text = pattern["settings"]["responses_template"]["input"][0]["content"][0]["text"] + assert "retrieval-augmented assistant" in system_text + assert "retrieved via file search" not in system_text + assert "provided documents" not in system_text.lower() + assert "Answer ONLY using information from the documents below" not in system_text + assert "must cite sources" not in system_text.lower() + assert "file citations" not in system_text.lower() + assert "max 150 words" in system_text + assert "with citations" not in system_text.lower() + assert "English only" in system_text + assert "{reference_documents}" not in system_text + assert "{question}" not in system_text + + def test_export_system_input_skips_duplicate_citation_and_keeps_answer_scaffold(self): + """Citation lines are stripped; answer scaffold and language policy still merge.""" + pattern = _make_pattern() + pattern["settings"]["generation"][ + "system_message_text" + ] = "You are a retrieval-augmented assistant. You MUST cite sources using [1], [2]." 
+ pattern["settings"]["generation"]["user_message_text"] = ( + "You MUST cite sources using [1], [2], etc.\n\n" + "Documents:\n{reference_documents}\n\n" + "Question: {question}\n\n" + "Answer (max 150 words, with citations):\n" + "You MUST write your entire answer in English only." + ) + + build_pattern_json(pattern) + + system_text = pattern["settings"]["responses_template"]["input"][0]["content"][0]["text"] + assert "must cite sources" not in system_text.lower() + assert "max 150 words" in system_text + assert "English only" in system_text + + def test_build_responses_system_input_legacy_user_prefix(self): + """Legacy grounding and citation lines are omitted; persona supplements are kept.""" + generation = { + "system_message_text": "Short system prefix.", + "user_message_text": ( + "Answer ONLY using information from the documents below.\n" + "You MUST cite sources using [1], [2].\n\n" + "Context: {reference_documents}\n\n" + "Question: {question}\n" + ), + } + system_input = build_responses_system_input(generation) + assert system_input == "Short system prefix." 
+ assert "retrieved via file search" not in system_input + assert "must cite sources" not in system_input.lower() + assert "documents below" not in system_input + + def test_build_pattern_json_uses_export_parity_system_input(self): + """build_pattern_json must use build_responses_system_input(), not raw system text.""" + model_id = "ibm/granite-3-8b-instruct" + expected = build_responses_system_input( + { + "system_message_text": get_system_message_text(model_id), + "user_message_text": get_user_message_text(model_id, language_autodetect=False), + } + ) + + pattern = _make_pattern() + pattern["settings"]["generation"]["model_id"] = model_id + pattern["settings"]["generation"]["system_message_text"] = get_system_message_text(model_id) + pattern["settings"]["generation"]["user_message_text"] = get_user_message_text( + model_id, language_autodetect=False + ) + + build_pattern_json(pattern) + + actual = pattern["settings"]["responses_template"]["input"][0]["content"][0]["text"] + assert actual == expected + assert actual != pattern["settings"]["generation"]["system_message_text"] + assert "Granite Chat" in actual + assert "Retrieval Augmented Generation" in actual + assert "i.e.," in actual + assert "English only" in actual + + @pytest.mark.parametrize( + "model_id", + [ + "unknown-model", + "ibm/granite-3-8b-instruct", + "meta-llama/llama-3-1-8b-instruct", + "mistralai/mistral-large", + "openai/gpt-oss-120b", + ], + ) + def test_export_omits_ogx_duplicative_prompt_text(self, model_id: str): + """Export must not duplicate citation/retrieval text that OGX injects at file_search runtime.""" + generation = { + "system_message_text": get_system_message_text(model_id), + "user_message_text": get_user_message_text(model_id, language_autodetect=False), + } + system_text = build_responses_system_input(generation) + + assert "[1], [2]" not in system_text + assert "must cite sources" not in system_text.lower() + assert "file citations" not in system_text.lower() + assert 
"documents below" not in system_text + assert "retrieved via file search" not in system_text + assert "retrieved to help answer" not in system_text.lower() + assert "<|file-id|>" not in system_text + assert "cite sources immediately" not in system_text.lower() + assert "supporting information only" not in system_text.lower() + assert "{reference_documents}" not in system_text + assert "{question}" not in system_text + assert "[End]" not in system_text + + def test_export_omits_ogx_config_yaml_instruction_text(self): + """Export must not contain verbatim OGX annotation/context template phrases.""" + generation = { + "system_message_text": ( + "You are a retrieval-augmented assistant. " + "Cite sources immediately at the end of sentences before punctuation." + ), + "user_message_text": ( + "The above results were retrieved to help answer the user's query. " + "Use them as supporting information only in answering this query.\n" + "Documents:\n{reference_documents}\n\n" + "Question: {question}\n" + ), + } + system_text = build_responses_system_input(generation) + assert system_text == "You are a retrieval-augmented assistant." 
+ + @pytest.mark.parametrize( + ("model_id", "expected_fragment"), + [ + ("meta-llama/llama-3-1-8b-instruct", "150 words"), + ("mistralai/mistral-large", "titles of documents"), + ("openai/gpt-oss-120b", "Answer Length: concise"), + ], + ) + def test_export_merges_family_specific_user_rules(self, model_id: str, expected_fragment: str): + """Each model family must preserve user-only rules in exported system input.""" + generation = { + "system_message_text": get_system_message_text(model_id), + "user_message_text": get_user_message_text(model_id, language_autodetect=False), + } + system_text = build_responses_system_input(generation) + assert expected_fragment in system_text + + def test_build_responses_system_input_handles_empty_inputs(self): + """When both system and user are empty or contain only placeholders, return fallback.""" + # Case 1: Completely empty + generation = { + "system_message_text": "", + "user_message_text": "", + } + system_text = build_responses_system_input(generation) + assert system_text == "You are a helpful assistant." + + # Case 2: Only placeholders in user template + generation = { + "system_message_text": "", + "user_message_text": "{reference_documents}\n{question}", + } + system_text = build_responses_system_input(generation) + assert system_text == "You are a helpful assistant." + + # Case 4: Unresolved question slot only (cookbook-style minimal user template) + generation = { + "system_message_text": "", + "user_message_text": "{question}", + } + system_text = build_responses_system_input(generation) + assert system_text == "You are a helpful assistant." 
+ + # Case 3: Only OGX-duplicative content that gets stripped + generation = { + "system_message_text": "Answer using ONLY the provided documents.", + "user_message_text": ( + "Answer ONLY using information from the documents below.\n" + "You MUST cite sources using [1], [2].\n" + "{reference_documents}\n{question}" + ), + } + system_text = build_responses_system_input(generation) + assert system_text == "You are a helpful assistant." + + def test_strip_ogx_runtime_partial_sentence_removal(self): + """OGX sentences in a multi-sentence system prompt are removed; others are kept.""" + generation = { + "system_message_text": ( + "You are an expert assistant. " "Answer using ONLY the provided documents. " "Be concise." + ), + "user_message_text": "", + } + result = build_responses_system_input(generation) + assert "You are an expert assistant" in result + assert "Be concise" in result + assert "provided documents" not in result.lower() + + def test_user_grounding_merges_when_system_is_persona_only(self): + """Persona-only system must not suppress non-OGX user supplements (e.g. RAG block).""" + generation = { + "system_message_text": "You are a retrieval-augmented assistant. Use your best judgment.", + "user_message_text": ( + "You are a specialized Retrieval Augmented Generation (RAG) assistant. " + "Prioritize correctness and ensure your response is grounded in the documents.\n" + "{reference_documents}\n{question}" + ), + } + result = build_responses_system_input(generation) + assert "retrieval-augmented assistant" in result + assert "specialized Retrieval Augmented Generation" in result + + def test_extract_static_user_pure_text_no_slots(self): + """Pure static user text without slots is merged into export.""" + generation = { + "system_message_text": "Short system.", + "user_message_text": "Always respond in a formal tone.", + } + result = build_responses_system_input(generation) + assert result == "Short system.\n\nAlways respond in a formal tone." 
+ + def test_normalize_answer_scaffold_strips_with_citations(self): + """Answer scaffolds must not retain citation hints owned by OGX.""" + assert _normalize_answer_scaffold("Answer (max 150 words, with citations):") == "Answer (max 150 words):" + + def test_build_pattern_json_requires_generation_model_id(self): + """Malformed generation payloads must raise KeyError for required fields.""" + pattern = _make_pattern() + del pattern["settings"]["generation"]["model_id"] + with pytest.raises(KeyError): + build_pattern_json(pattern) + + def test_system_grounding_detection_requires_explicit_policy(self): + """Grounding detection must require explicit 'ONLY' constraint, not just persona.""" + generation_persona_only = { + "system_message_text": "You are a retrieval-augmented assistant. Use your best judgment.", + "user_message_text": "Answer ONLY using information from the documents below.\n{reference_documents}\n{question}", + } + result_persona = build_responses_system_input(generation_persona_only) + # Persona-only system should NOT trigger grounding suppression + assert "retrieval-augmented assistant" in result_persona + assert "use your best judgment" in result_persona.lower() + + generation_explicit = { + "system_message_text": "Answer using ONLY the provided documents.", + "user_message_text": "Answer ONLY using information from the documents below.\n{reference_documents}\n{question}", + } + result_explicit = build_responses_system_input(generation_explicit) + # Explicit grounding system SHOULD suppress redundant user grounding + # Both prompts are OGX-duplicative, so fallback is used + assert result_explicit == "You are a helpful assistant." 
+ + def test_system_grounding_detection_uses_grounding_prefixes(self): + """Grounding detection must cover all ``_GROUNDING_PREFIXES`` entries.""" + generation = { + "system_message_text": "Answer ONLY using information from documents retrieved via file search.", + "user_message_text": ( + "Answer ONLY using information from the documents below.\n{reference_documents}\n{question}" + ), + } + result = build_responses_system_input(generation) + assert result == "You are a helpful assistant." + def test_preserves_existing_pattern_fields(self): """Existing pattern fields (name, chunking, embedding, etc.) must not be altered.""" pattern = _make_pattern() From 3cfa76eb002ee6116dc27755e97052964c0dab04 Mon Sep 17 00:00:00 2001 From: Lukasz Cmielowski Date: Tue, 30 Jun 2026 10:21:56 +0200 Subject: [PATCH 10/16] test(assets): update tests for PR #81 unified prompt templates - Replace language_autodetect=False with language='English' - Update test expectations to match unified RAG instruction format - Remove checks for deprecated model-specific fragments Signed-off-by: Lukasz Cmielowski Assisted-by: Cursor --- .../assets_generator/test_pattern_builder.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py index 54e5342..16a3a3a 100644 --- a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py +++ b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py @@ -244,7 +244,7 @@ def test_build_pattern_json_uses_export_parity_system_input(self): expected = build_responses_system_input( { "system_message_text": get_system_message_text(model_id), - "user_message_text": get_user_message_text(model_id, language_autodetect=False), + "user_message_text": get_user_message_text(model_id, language="English"), } ) @@ -252,7 +252,7 @@ def test_build_pattern_json_uses_export_parity_system_input(self): 
pattern["settings"]["generation"]["model_id"] = model_id pattern["settings"]["generation"]["system_message_text"] = get_system_message_text(model_id) pattern["settings"]["generation"]["user_message_text"] = get_user_message_text( - model_id, language_autodetect=False + model_id, language="English" ) build_pattern_json(pattern) @@ -261,9 +261,8 @@ def test_build_pattern_json_uses_export_parity_system_input(self): assert actual == expected assert actual != pattern["settings"]["generation"]["system_message_text"] assert "Granite Chat" in actual - assert "Retrieval Augmented Generation" in actual - assert "i.e.," in actual - assert "English only" in actual + assert "retrieval-augmented assistant" in actual + assert "You MUST respond in English" in actual @pytest.mark.parametrize( "model_id", @@ -279,7 +278,7 @@ def test_export_omits_ogx_duplicative_prompt_text(self, model_id: str): """Export must not duplicate citation/retrieval text that OGX injects at file_search runtime.""" generation = { "system_message_text": get_system_message_text(model_id), - "user_message_text": get_user_message_text(model_id, language_autodetect=False), + "user_message_text": get_user_message_text(model_id, language="English"), } system_text = build_responses_system_input(generation) @@ -314,21 +313,25 @@ def test_export_omits_ogx_config_yaml_instruction_text(self): assert system_text == "You are a retrieval-augmented assistant." 
@pytest.mark.parametrize( - ("model_id", "expected_fragment"), + "model_id", [ - ("meta-llama/llama-3-1-8b-instruct", "150 words"), - ("mistralai/mistral-large", "titles of documents"), - ("openai/gpt-oss-120b", "Answer Length: concise"), + "meta-llama/llama-3-1-8b-instruct", + "mistralai/mistral-large", + "openai/gpt-oss-120b", + "ibm/granite-3-8b-instruct", ], ) - def test_export_merges_family_specific_user_rules(self, model_id: str, expected_fragment: str): - """Each model family must preserve user-only rules in exported system input.""" + def test_export_merges_unified_rag_instructions(self, model_id: str): + """All model families use unified RAG instructions after PR #81.""" generation = { "system_message_text": get_system_message_text(model_id), - "user_message_text": get_user_message_text(model_id, language_autodetect=False), + "user_message_text": get_user_message_text(model_id, language="English"), } system_text = build_responses_system_input(generation) - assert expected_fragment in system_text + # All models now use unified RAG structure from PR #81 + assert "retrieval-augmented assistant" in system_text + assert "max 150 words" in system_text + assert "You MUST respond in English" in system_text def test_build_responses_system_input_handles_empty_inputs(self): """When both system and user are empty or contain only placeholders, return fallback.""" From f0f298bda9b40733d50df608c469112366fefcff Mon Sep 17 00:00:00 2001 From: Lukasz Cmielowski Date: Tue, 30 Jun 2026 10:26:38 +0200 Subject: [PATCH 11/16] style: fix black formatting Remove extra blank line between variable assignment and dict creation. 
Signed-off-by: Lukasz Cmielowski Assisted-by: Cursor --- ai4rag/components/assets_generator/pattern_builder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ai4rag/components/assets_generator/pattern_builder.py b/ai4rag/components/assets_generator/pattern_builder.py index bed3937..c3b95b4 100644 --- a/ai4rag/components/assets_generator/pattern_builder.py +++ b/ai4rag/components/assets_generator/pattern_builder.py @@ -437,7 +437,6 @@ def build_pattern_json( generation = pattern["settings"]["generation"] system_input = build_responses_system_input(generation) - pattern["settings"]["responses_template"] = { "model": generation["model_id"], "stream": False, From 8c910d987c600d8ac4e84d1613baf8e9e963c2d7 Mon Sep 17 00:00:00 2001 From: Lukasz Cmielowski Date: Tue, 30 Jun 2026 10:31:24 +0200 Subject: [PATCH 12/16] refactor(assets): extract OGX filtering logic to separate module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract 260 lines of prompt filtering logic from pattern_builder.py into new prompt_filters.py module for better separation of concerns. Changes: - Create ai4rag/components/assets_generator/prompt_filters.py - Move all OGX deduplication constants and functions - Add comprehensive module docstring - Make functions public (remove underscore prefixes) - Update pattern_builder.py (487 → 316 lines, -35%) - Import filtering functions from prompt_filters - Remove duplicated filtering logic - Focus on pattern building responsibilities - Update tests to import from new module Benefits: - Clear separation: pattern building vs text filtering - Better maintainability: OGX phrase lists in one place - Easier testing: filter logic independently testable - Reduced cognitive load when reviewing pattern code All 35 tests passing. 
Signed-off-by: Lukasz Cmielowski Assisted-by: Cursor --- .../assets_generator/pattern_builder.py | 205 ++------------ .../assets_generator/prompt_filters.py | 260 ++++++++++++++++++ .../assets_generator/test_pattern_builder.py | 4 +- 3 files changed, 279 insertions(+), 190 deletions(-) create mode 100644 ai4rag/components/assets_generator/prompt_filters.py diff --git a/ai4rag/components/assets_generator/pattern_builder.py b/ai4rag/components/assets_generator/pattern_builder.py index c3b95b4..0addd34 100644 --- a/ai4rag/components/assets_generator/pattern_builder.py +++ b/ai4rag/components/assets_generator/pattern_builder.py @@ -2,14 +2,15 @@ # Copyright IBM Corp. 2026 # SPDX-License-Identifier: Apache-2.0 # ----------------------------------------------------------------------------- -"""Map HPO Chat Completions prompts to exported Responses ``input[system]`` text. - -OGX-owned phrases are defined below and must stay aligned with -``benchmarking/rag/config.yaml`` (``file_search_params``, ``context_prompt_params``, -``annotation_prompt_params``). If OGX injection strings change, update the lists here. -""" - -import re +"""Build Responses API pattern definitions from HPO experiment results.""" + +from ai4rag.components.assets_generator.prompt_filters import ( + GROUNDING_PREFIXES, + USER_GROUNDING_SKIP_PREFIXES, + USER_RAG_GROUNDING_PREFIXES, + is_citation_related_line, + strip_ogx_runtime_instructions, +) _USER_QUERY_PLACEHOLDER = "" _EMPTY_SYSTEM_FALLBACK = "You are a helpful assistant." @@ -18,183 +19,11 @@ # Suffix lines after ``{reference_documents}``: drop structural wrappers (e.g. ``[End]``). 
_DOCUMENT_SLOT_MARKERS = frozenset({"[Document]", "[End]", "Documents:", "Context:"}) -# ============================================================================ -# OGX Runtime Injection Strings -# ============================================================================ -# These phrases are injected by OGX at file_search runtime via -# benchmarking/rag/config.yaml (file_search_params, context_prompt_params, -# annotation_prompt_params). HPO export must NOT duplicate them in -# responses_template.input[system]. -# -# If OGX changes injection strings in config.yaml, update these lists. -# ============================================================================ - -# Citation-related phrases -_CITATION_PREFIXES = ( - "You MUST cite sources", - "Cite sources immediately", -) -_CITATION_SUBSTRINGS = ( - "[1], [2]", - "<|file-id|>", - "cite as <|", - "file citations", - "document numbers for every factual claim", -) -_HPO_CITATION_INSTRUCTION = ( - "You MUST cite sources using [1], [2], etc. matching the document numbers for every factual claim." 
-) -_HPO_CITATION_FRAGMENTS = ( - _HPO_CITATION_INSTRUCTION, - "You MUST cite sources using [1], [2], etc.", - "You MUST cite sources using [1], [2].", -) - -# Grounding/retrieval-related phrases -_GROUNDING_PREFIXES = ( - "Answer ONLY using information from the documents", - "Answer ONLY using information from documents retrieved", - "Answer using ONLY the provided documents", - "Answer using ONLY information from documents", - "Do not use outside knowledge", - "If the retrieved documents do not contain", - "If the documents do not contain", -) -_GROUNDING_SUBSTRINGS = ( - "documents below", - "retrieved via file search", - "retrieved to help answer the user", - "supporting information only in answering", -) -_SYSTEM_GROUNDING_PHRASES = ( - "Answer using ONLY the provided documents.", - "Answer using ONLY information from documents retrieved via file search.", -) - -# File search tool markers -_FILE_SEARCH_MARKERS = ( - "file_search tool found", - "BEGIN of file_search tool results", - "END of file_search tool results", - "The above results were retrieved to help answer", - "Use them as supporting information only", - "Do not add extra punctuation. 
Use only the file IDs", -) - -# User template duplicate detection (pass 1 filtering) -_USER_GROUNDING_SKIP_PREFIXES = ( - "Answer ONLY using information from the documents below", - "Do not use outside knowledge", - "If the documents do not contain the answer", -) -_USER_RAG_GROUNDING_PREFIXES = ( - "You are a specialized Retrieval Augmented Generation", - "Prioritize correctness and ensure your response is grounded", -) - # Document and question slot markers _DOCUMENT_LABELS = ("Documents:", "Context:", "[Document]") _QUESTION_PREFIXES = ("Question:", "Q:", "[conversation]:") _LEGACY_DOCUMENT_MARKERS = ("Documents:\n", "Context:\n", "[Document]\n") -# Combined line prefixes for sentence-level filtering -_OGX_DUPLICATIVE_LINE_PREFIXES = _CITATION_PREFIXES + _GROUNDING_PREFIXES + _FILE_SEARCH_MARKERS - -# Combined substrings for partial-match filtering -_OGX_DUPLICATIVE_SUBSTRINGS = _CITATION_SUBSTRINGS + _GROUNDING_SUBSTRINGS - - -def _collapse_whitespace(text: str) -> str: - """Collapse repeated interior spaces after phrase removal.""" - return re.sub(r" +", " ", text).strip() - - -def _sentence_is_ogx_duplicative(sentence: str) -> bool: - """Return whether a sentence duplicates OGX file_search runtime injection.""" - stripped = sentence.strip().rstrip(".") - if not stripped: - return True - if any(stripped.startswith(prefix.rstrip(".")) for prefix in _OGX_DUPLICATIVE_LINE_PREFIXES): - return True - normalized = stripped.lower() - return any(fragment.lower() in normalized for fragment in _OGX_DUPLICATIVE_SUBSTRINGS) - - -def _is_citation_related_line(line: str) -> bool: - """Return whether an entire line should be dropped as citation guidance.""" - stripped = line.strip() - if not stripped: - return False - lower = stripped.lower() - if any(stripped.startswith(prefix) for prefix in _CITATION_PREFIXES): - return True - if any(fragment.lower() in lower for fragment in _HPO_CITATION_FRAGMENTS): - return True - return any(sub.lower() in lower for sub in 
_CITATION_SUBSTRINGS) - - -def _filter_ogx_duplicative_sentences(line: str) -> str: - """Remove OGX-duplicative sentences while keeping persona or policy sentences.""" - stripped = line.strip() - if not stripped or _is_citation_related_line(stripped): - return "" - - # Split on ". " only — avoids breaking abbreviations such as "i.e.," - parts = [part.strip() for part in stripped.split(". ") if part.strip()] - if len(parts) <= 1: - if _sentence_is_ogx_duplicative(stripped.rstrip(".")): - return "" - return stripped - - kept = [part.rstrip(".") for part in parts if not _sentence_is_ogx_duplicative(part.rstrip("."))] - if not kept: - return "" - - result = ". ".join(kept) - if stripped.endswith("."): - result += "." - return result - - -def _normalize_answer_scaffold(line: str) -> str: - """Drop citation hints from answer scaffolds; OGX owns citation via annotations.""" - normalized = line.replace(", with citations", "").replace("with citations", "") - return _collapse_whitespace(normalized) - - -def _strip_ogx_runtime_instructions(text: str) -> str: - """Remove text that OGX injects via file_search config at inference time.""" - if not text.strip(): - return "" - - for phrase in _SYSTEM_GROUNDING_PHRASES: - text = text.replace(phrase, "").replace(phrase.rstrip("."), "") - text = _collapse_whitespace(text) - - lines: list[str] = [] - for line in text.splitlines(): - stripped = line.strip() - if not stripped: - if lines and lines[-1] != "": - lines.append("") - continue - if _is_citation_related_line(stripped): - continue - - cleaned = _filter_ogx_duplicative_sentences(stripped) - for fragment in _HPO_CITATION_FRAGMENTS: - if fragment in cleaned: - cleaned = cleaned.replace(fragment, "").strip() - break - cleaned = _normalize_answer_scaffold(cleaned) - if cleaned: - lines.append(cleaned) - - result = "\n".join(lines) - while "\n\n\n" in result: - result = result.replace("\n\n\n", "\n\n") - return result.strip() - def _join_answer_scaffold_blocks(lines: list[str]) -> 
str: """Group lines into paragraph blocks, starting a new block when an answer-scaffold line appears. @@ -222,18 +51,18 @@ def _join_answer_scaffold_blocks(lines: list[str]) -> str: def _should_skip_redundant_user_line(stripped: str, system_has_grounding: bool) -> bool: """Return whether a user-template line duplicates system policy for export.""" - if _is_citation_related_line(stripped): + if is_citation_related_line(stripped): return True return system_has_grounding and any( - stripped.startswith(prefix) for prefix in _GROUNDING_PREFIXES + _USER_RAG_GROUNDING_PREFIXES + stripped.startswith(prefix) for prefix in GROUNDING_PREFIXES + USER_RAG_GROUNDING_PREFIXES ) def _should_skip_user_export_line(stripped: str) -> bool: """Return whether a merged user line is OGX-owned and must not be exported.""" - if any(stripped.startswith(prefix) for prefix in _USER_GROUNDING_SKIP_PREFIXES): + if any(stripped.startswith(prefix) for prefix in USER_GROUNDING_SKIP_PREFIXES): return True - return _is_citation_related_line(stripped) + return is_citation_related_line(stripped) def _strip_document_slot_prefix(prefix: str) -> str: @@ -296,12 +125,12 @@ def _system_has_grounding_policy(system: str) -> bool: """Return whether the system prompt already states an explicit document-only grounding rule. Uses the same prefix list as sentence-level filtering so that adding a new - OGX phrase to ``_GROUNDING_PREFIXES`` automatically covers system detection too. + OGX phrase to ``GROUNDING_PREFIXES`` automatically covers system detection too. Does NOT match descriptive personas like "retrieval-augmented assistant" without an explicit grounding constraint. 
""" normalized = system.lower() - return any(prefix.lower() in normalized for prefix in _GROUNDING_PREFIXES) + return any(prefix.lower() in normalized for prefix in GROUNDING_PREFIXES) def _filter_static_user_for_responses(system: str, static_user: str) -> str: @@ -328,7 +157,7 @@ def _filter_static_user_for_responses(system: str, static_user: str) -> str: def _adapt_system_for_responses_export(system: str) -> str: """Drop OGX-runtime retrieval/citation text from the HPO system prompt.""" - return _strip_ogx_runtime_instructions(system) + return strip_ogx_runtime_instructions(system) def _adapt_static_user_for_responses_export(static_user: str) -> str: @@ -344,7 +173,7 @@ def _adapt_static_user_for_responses_export(static_user: str) -> str: stripped = line.strip() if not stripped or _should_skip_user_export_line(stripped): continue - cleaned = _strip_ogx_runtime_instructions(stripped) + cleaned = strip_ogx_runtime_instructions(stripped) if cleaned: adapted_lines.append(cleaned) diff --git a/ai4rag/components/assets_generator/prompt_filters.py b/ai4rag/components/assets_generator/prompt_filters.py new file mode 100644 index 0000000..2157c3e --- /dev/null +++ b/ai4rag/components/assets_generator/prompt_filters.py @@ -0,0 +1,260 @@ +# ----------------------------------------------------------------------------- +# Copyright IBM Corp. 2026 +# SPDX-License-Identifier: Apache-2.0 +# ----------------------------------------------------------------------------- +"""Filter HPO prompts to remove OGX runtime injection duplicates. + +OGX (OpenSearch GenAI eXperience) injects grounding, citation, and retrieval +instructions at runtime via benchmarking/rag/config.yaml. HPO (HyperParameter +Optimization) templates sometimes include similar phrases that must be removed +during Responses API export to avoid duplication. + +This module provides filtering functions to strip OGX-owned content while +preserving HPO-specific persona, policy, and answer formatting rules. 
+ +Note +---- +OGX phrase lists must stay synchronized with benchmarking/rag/config.yaml. +If OGX updates their injection strings, update the constants below. +""" + +import re + +# ============================================================================ +# OGX Runtime Injection Strings +# ============================================================================ +# Source: benchmarking/rag/config.yaml +# These phrases are injected by OGX at file_search runtime. +# Export must NOT duplicate them in responses_template.input[system]. +# ============================================================================ + +# Citation-related phrases +CITATION_PREFIXES = ( + "You MUST cite sources", + "Cite sources immediately", +) +CITATION_SUBSTRINGS = ( + "[1], [2]", + "<|file-id|>", + "cite as <|", + "file citations", + "document numbers for every factual claim", +) +HPO_CITATION_INSTRUCTION = ( + "You MUST cite sources using [1], [2], etc. matching the document numbers for every factual claim." 
+) +HPO_CITATION_FRAGMENTS = ( + HPO_CITATION_INSTRUCTION, + "You MUST cite sources using [1], [2], etc.", + "You MUST cite sources using [1], [2].", +) + +# Grounding/retrieval-related phrases +GROUNDING_PREFIXES = ( + "Answer ONLY using information from the documents", + "Answer ONLY using information from documents retrieved", + "Answer using ONLY the provided documents", + "Answer using ONLY information from documents", + "Do not use outside knowledge", + "If the retrieved documents do not contain", + "If the documents do not contain", +) +GROUNDING_SUBSTRINGS = ( + "documents below", + "retrieved via file search", + "retrieved to help answer the user", + "supporting information only in answering", +) +SYSTEM_GROUNDING_PHRASES = ( + "Answer using ONLY the provided documents.", + "Answer using ONLY information from documents retrieved via file search.", +) + +# File search tool markers +FILE_SEARCH_MARKERS = ( + "file_search tool found", + "BEGIN of file_search tool results", + "END of file_search tool results", + "The above results were retrieved to help answer", + "Use them as supporting information only", + "Do not add extra punctuation. Use only the file IDs", +) + +# User template duplicate detection (pass 1 filtering) +USER_GROUNDING_SKIP_PREFIXES = ( + "Answer ONLY using information from the documents below", + "Do not use outside knowledge", + "If the documents do not contain the answer", +) +USER_RAG_GROUNDING_PREFIXES = ( + "You are a specialized Retrieval Augmented Generation", + "Prioritize correctness and ensure your response is grounded", +) + +# Combined line prefixes for sentence-level filtering +OGX_DUPLICATIVE_LINE_PREFIXES = CITATION_PREFIXES + GROUNDING_PREFIXES + FILE_SEARCH_MARKERS + +# Combined substrings for partial-match filtering +OGX_DUPLICATIVE_SUBSTRINGS = CITATION_SUBSTRINGS + GROUNDING_SUBSTRINGS + + +def collapse_whitespace(text: str) -> str: + """Collapse repeated interior spaces after phrase removal. 
+ + Parameters + ---------- + text : str + Text potentially containing multiple consecutive spaces. + + Returns + ------- + str + Text with interior whitespace collapsed to single spaces, stripped. + """ + return re.sub(r" +", " ", text).strip() + + +def is_sentence_ogx_duplicative(sentence: str) -> bool: + """Return whether a sentence duplicates OGX file_search runtime injection. + + Parameters + ---------- + sentence : str + Single sentence to check. + + Returns + ------- + bool + True if sentence matches OGX injection patterns. + """ + stripped = sentence.strip().rstrip(".") + if not stripped: + return True + if any(stripped.startswith(prefix.rstrip(".")) for prefix in OGX_DUPLICATIVE_LINE_PREFIXES): + return True + normalized = stripped.lower() + return any(fragment.lower() in normalized for fragment in OGX_DUPLICATIVE_SUBSTRINGS) + + +def is_citation_related_line(line: str) -> bool: + """Return whether an entire line should be dropped as citation guidance. + + Parameters + ---------- + line : str + Line of text to check. + + Returns + ------- + bool + True if line contains only citation instructions owned by OGX. + """ + stripped = line.strip() + if not stripped: + return False + lower = stripped.lower() + if any(stripped.startswith(prefix) for prefix in CITATION_PREFIXES): + return True + if any(fragment.lower() in lower for fragment in HPO_CITATION_FRAGMENTS): + return True + return any(sub.lower() in lower for sub in CITATION_SUBSTRINGS) + + +def filter_ogx_duplicative_sentences(line: str) -> str: + """Remove OGX-duplicative sentences while keeping persona or policy sentences. + + Handles multi-sentence lines by filtering at sentence granularity. + + Parameters + ---------- + line : str + Line potentially containing multiple sentences. + + Returns + ------- + str + Line with OGX-duplicative sentences removed, or empty string if all filtered. 
+ """ + stripped = line.strip() + if not stripped or is_citation_related_line(stripped): + return "" + + # Split on ". " only — avoids breaking abbreviations such as "i.e.," + parts = [part.strip() for part in stripped.split(". ") if part.strip()] + if len(parts) <= 1: + if is_sentence_ogx_duplicative(stripped.rstrip(".")): + return "" + return stripped + + kept = [part.rstrip(".") for part in parts if not is_sentence_ogx_duplicative(part.rstrip("."))] + if not kept: + return "" + + result = ". ".join(kept) + if stripped.endswith("."): + result += "." + return result + + +def normalize_answer_scaffold(line: str) -> str: + """Drop citation hints from answer scaffolds; OGX owns citation via annotations. + + Parameters + ---------- + line : str + Line potentially containing answer scaffold with citation hints. + + Returns + ------- + str + Line with ", with citations" and "with citations" removed, whitespace normalized. + """ + normalized = line.replace(", with citations", "").replace("with citations", "") + return collapse_whitespace(normalized) + + +def strip_ogx_runtime_instructions(text: str) -> str: + """Remove text that OGX injects via file_search config at inference time. + + This is the main filtering function that orchestrates all OGX deduplication. + + Parameters + ---------- + text : str + Raw HPO prompt text (system or user message). + + Returns + ------- + str + Filtered text with OGX-duplicative content removed. 
+ """ + if not text.strip(): + return "" + + for phrase in SYSTEM_GROUNDING_PHRASES: + text = text.replace(phrase, "").replace(phrase.rstrip("."), "") + text = collapse_whitespace(text) + + lines: list[str] = [] + for line in text.splitlines(): + stripped = line.strip() + if not stripped: + if lines and lines[-1] != "": + lines.append("") + continue + if is_citation_related_line(stripped): + continue + + cleaned = filter_ogx_duplicative_sentences(stripped) + for fragment in HPO_CITATION_FRAGMENTS: + if fragment in cleaned: + cleaned = cleaned.replace(fragment, "").strip() + break + cleaned = normalize_answer_scaffold(cleaned) + if cleaned: + lines.append(cleaned) + + result = "\n".join(lines) + while "\n\n\n" in result: + result = result.replace("\n\n\n", "\n\n") + return result.strip() diff --git a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py index 16a3a3a..8338c0c 100644 --- a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py +++ b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py @@ -11,9 +11,9 @@ from ai4rag.components.assets_generator import build_pattern_json from ai4rag.components.assets_generator.pattern_builder import ( _is_placeholder_only_export, - _normalize_answer_scaffold, build_responses_system_input, ) +from ai4rag.components.assets_generator.prompt_filters import normalize_answer_scaffold from ai4rag.search_space.src.model_props import get_system_message_text, get_user_message_text # --------------------------------------------------------------------------- @@ -409,7 +409,7 @@ def test_extract_static_user_pure_text_no_slots(self): def test_normalize_answer_scaffold_strips_with_citations(self): """Answer scaffolds must not retain citation hints owned by OGX.""" - assert _normalize_answer_scaffold("Answer (max 150 words, with citations):") == "Answer (max 150 words):" + assert normalize_answer_scaffold("Answer (max 150 words, with citations):") == "Answer 
(max 150 words):" def test_build_pattern_json_requires_generation_model_id(self): """Malformed generation payloads must raise KeyError for required fields.""" From 260fb90f3584e25610c19fad3f04464aa7e18dd2 Mon Sep 17 00:00:00 2001 From: Lukasz Cmielowski Date: Tue, 30 Jun 2026 12:53:19 +0200 Subject: [PATCH 13/16] refactor(assets): apply all code review fixes and simplifications Critical fixes: - Remove legacy template support code (only support {reference_documents} format) - Fix _system_has_grounding_policy to use sentence-level startswith instead of substring Prevents false positives like "All documents do not contain PII" - Add None guards for temperature and max_completion_tokens Prevents sending null values to Responses API - Replace while loop with single-pass regex in prompt_filters.py Test coverage: - Add test_omits_temperature_when_none - Add test_omits_max_output_tokens_when_none - Add test_system_grounding_detection_no_false_positive_on_embedded_substring - Update test_extract_static_user_pure_text_no_slots for modern templates only - Rename test_build_responses_system_input_legacy_user_prefix to test_build_responses_system_input_strips_ogx_prefix All 38 tests passing. Black formatting passing. 
Signed-off-by: Lukasz Cmielowski Assisted-by: Cursor --- .../assets_generator/pattern_builder.py | 40 ++++++++-------- .../assets_generator/prompt_filters.py | 3 +- .../assets_generator/test_pattern_builder.py | 46 ++++++++++++++++--- 3 files changed, 61 insertions(+), 28 deletions(-) diff --git a/ai4rag/components/assets_generator/pattern_builder.py b/ai4rag/components/assets_generator/pattern_builder.py index 0addd34..95c3c71 100644 --- a/ai4rag/components/assets_generator/pattern_builder.py +++ b/ai4rag/components/assets_generator/pattern_builder.py @@ -4,6 +4,8 @@ # ----------------------------------------------------------------------------- """Build Responses API pattern definitions from HPO experiment results.""" +import re + from ai4rag.components.assets_generator.prompt_filters import ( GROUNDING_PREFIXES, USER_GROUNDING_SKIP_PREFIXES, @@ -22,7 +24,6 @@ # Document and question slot markers _DOCUMENT_LABELS = ("Documents:", "Context:", "[Document]") _QUESTION_PREFIXES = ("Question:", "Q:", "[conversation]:") -_LEGACY_DOCUMENT_MARKERS = ("Documents:\n", "Context:\n", "[Document]\n") def _join_answer_scaffold_blocks(lines: list[str]) -> str: @@ -111,16 +112,6 @@ def _extract_static_user_from_reference_slot(text: str) -> str: return "\n\n".join(parts).strip() -def _extract_static_user_without_reference_slot(text: str) -> str: - """Extract static instructions from legacy templates without an explicit doc slot.""" - doc_idx = len(text) - for marker in _LEGACY_DOCUMENT_MARKERS: - idx = text.find(marker) - if idx != -1: - doc_idx = min(doc_idx, idx) - return text[:doc_idx].strip() - - def _system_has_grounding_policy(system: str) -> bool: """Return whether the system prompt already states an explicit document-only grounding rule. @@ -128,9 +119,11 @@ def _system_has_grounding_policy(system: str) -> bool: OGX phrase to ``GROUNDING_PREFIXES`` automatically covers system detection too. 
Does NOT match descriptive personas like "retrieval-augmented assistant" without an explicit grounding constraint. + + Checks at sentence granularity to avoid false positives from embedded substrings. """ - normalized = system.lower() - return any(prefix.lower() in normalized for prefix in GROUNDING_PREFIXES) + sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", system)] + return any(any(sent.lower().startswith(p.lower()) for p in GROUNDING_PREFIXES) for sent in sentences) def _filter_static_user_for_responses(system: str, static_user: str) -> str: @@ -185,16 +178,17 @@ def _extract_static_user_instructions(user_message_text: str) -> str: Strips runtime slots (retrieved documents, question) that Responses API supplies via ``file_search`` and the user ``input`` message respectively. + + All current templates use {reference_documents} placeholder format. """ if not user_message_text: return "" text = str(user_message_text).strip() - if "{reference_documents}" in text: - return _extract_static_user_from_reference_slot(text) + if "{reference_documents}" not in text: + return "" - prefix = _extract_static_user_without_reference_slot(text) - return prefix + return _extract_static_user_from_reference_slot(text) def _is_placeholder_only_export(text: str) -> bool: @@ -266,7 +260,7 @@ def build_pattern_json( generation = pattern["settings"]["generation"] system_input = build_responses_system_input(generation) - pattern["settings"]["responses_template"] = { + responses_template = { "model": generation["model_id"], "stream": False, "store": False, @@ -277,8 +271,6 @@ def build_pattern_json( }, {"content": [{"text": _USER_QUERY_PLACEHOLDER, "type": "input_text"}], "role": "user"}, ], - "max_output_tokens": generation["max_completion_tokens"], - "temperature": generation["temperature"], "tool_choice": {"mode": "required", "tools": [{}], "type": "file_search"}, "tools": [ { @@ -290,6 +282,14 @@ def build_pattern_json( "include": ["file_search_call.results"], } + # Only 
include temperature and max_output_tokens if they are not None + if generation.get("temperature") is not None: + responses_template["temperature"] = generation["temperature"] + if generation.get("max_completion_tokens") is not None: + responses_template["max_output_tokens"] = generation["max_completion_tokens"] + + pattern["settings"]["responses_template"] = responses_template + retrieval_settings = pattern["settings"]["retrieval"] search_mode = retrieval_settings.get("search_mode") ranker_strategy = retrieval_settings.get("ranker_strategy") diff --git a/ai4rag/components/assets_generator/prompt_filters.py b/ai4rag/components/assets_generator/prompt_filters.py index 2157c3e..861cd79 100644 --- a/ai4rag/components/assets_generator/prompt_filters.py +++ b/ai4rag/components/assets_generator/prompt_filters.py @@ -255,6 +255,5 @@ def strip_ogx_runtime_instructions(text: str) -> str: lines.append(cleaned) result = "\n".join(lines) - while "\n\n\n" in result: - result = result.replace("\n\n\n", "\n\n") + result = re.sub(r"\n{3,}", "\n\n", result) return result.strip() diff --git a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py index 8338c0c..6154052 100644 --- a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py +++ b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py @@ -221,7 +221,7 @@ def test_export_system_input_skips_duplicate_citation_and_keeps_answer_scaffold( assert "max 150 words" in system_text assert "English only" in system_text - def test_build_responses_system_input_legacy_user_prefix(self): + def test_build_responses_system_input_strips_ogx_prefix(self): """Legacy grounding and citation lines are omitted; persona supplements are kept.""" generation = { "system_message_text": "Short system prefix.", @@ -251,9 +251,7 @@ def test_build_pattern_json_uses_export_parity_system_input(self): pattern = _make_pattern() pattern["settings"]["generation"]["model_id"] = model_id 
pattern["settings"]["generation"]["system_message_text"] = get_system_message_text(model_id) - pattern["settings"]["generation"]["user_message_text"] = get_user_message_text( - model_id, language="English" - ) + pattern["settings"]["generation"]["user_message_text"] = get_user_message_text(model_id, language="English") build_pattern_json(pattern) @@ -399,13 +397,14 @@ def test_user_grounding_merges_when_system_is_persona_only(self): assert "specialized Retrieval Augmented Generation" in result def test_extract_static_user_pure_text_no_slots(self): - """Pure static user text without slots is merged into export.""" + """Templates without {reference_documents} are invalid and return empty user text.""" generation = { "system_message_text": "Short system.", "user_message_text": "Always respond in a formal tone.", } result = build_responses_system_input(generation) - assert result == "Short system.\n\nAlways respond in a formal tone." + # Invalid template (no {reference_documents}) → system only + assert result == "Short system." 
def test_normalize_answer_scaffold_strips_with_citations(self): """Answer scaffolds must not retain citation hints owned by OGX.""" @@ -459,3 +458,38 @@ def test_preserves_existing_pattern_fields(self): assert pattern["name"] == original_name assert pattern["settings"]["chunking"] == original_chunking + + def test_omits_temperature_when_none(self): + """Temperature field must be omitted when None to avoid sending null to API.""" + pattern = _make_pattern() + pattern["settings"]["generation"]["temperature"] = None + + build_pattern_json(pattern) + + assert "temperature" not in pattern["settings"]["responses_template"] + # max_output_tokens should still be present + assert "max_output_tokens" in pattern["settings"]["responses_template"] + + def test_omits_max_output_tokens_when_none(self): + """max_output_tokens field must be omitted when None to avoid sending null to API.""" + pattern = _make_pattern() + pattern["settings"]["generation"]["max_completion_tokens"] = None + + build_pattern_json(pattern) + + assert "max_output_tokens" not in pattern["settings"]["responses_template"] + # temperature should still be present + assert "temperature" in pattern["settings"]["responses_template"] + + def test_system_grounding_detection_no_false_positive_on_embedded_substring(self): + """Grounding detection must not match embedded substrings, only sentence prefixes.""" + generation = { + "system_message_text": "Use only relevant information. 
All documents do not contain PII.", + "user_message_text": "Answer ONLY using information from the documents below.\n{reference_documents}\n{question}", + } + result = build_responses_system_input(generation) + + # "documents do not contain" is part of a GROUNDING_PREFIXES entry but appears mid-sentence + # Should NOT suppress user grounding since system doesn't start with a grounding prefix + assert "Use only relevant information" in result + assert "All documents do not contain PII" in result From f95e919b4faa35d2d0506e581f12453bdd4b7706 Mon Sep 17 00:00:00 2001 From: Lukasz Cmielowski Date: Wed, 1 Jul 2026 12:52:17 +0200 Subject: [PATCH 14/16] fix(chunking): convert headings list to string in docling chunker Change headings metadata from list to string format using " > " separator. This makes the metadata more readable and consistent with other formats. Before: metadata["headings"] = ["Chapter 1", "Section 1.1"] After: metadata["headings"] = "Chapter 1 > Section 1.1" Signed-off-by: Lukasz Cmielowski Assisted-by: Cursor --- ai4rag/rag/chunking/docling_chunker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ai4rag/rag/chunking/docling_chunker.py b/ai4rag/rag/chunking/docling_chunker.py index bcf50f9..4aedb16 100644 --- a/ai4rag/rag/chunking/docling_chunker.py +++ b/ai4rag/rag/chunking/docling_chunker.py @@ -100,7 +100,7 @@ def split_documents(self, documents: Sequence[DoclingDocument]) -> list[AI4RAGCh } if chunk.meta.headings: - metadata["headings"] = chunk.meta.headings + metadata["headings"] = " > ".join(chunk.meta.headings) all_chunks.append(AI4RAGChunk(text=text, metadata=metadata)) From a9d89e4bc0788867de035a9ba2b3a8fd52aaa069 Mon Sep 17 00:00:00 2001 From: Lukasz Cmielowski Date: Wed, 1 Jul 2026 12:59:35 +0200 Subject: [PATCH 15/16] test(chunking): update tests for string-based headings format Update tests to handle headings as string ("A > B") instead of list. Extract first heading from the string using .split(" > ")[0].
Signed-off-by: Lukasz Cmielowski Assisted-by: Cursor --- .../unit/ai4rag/rag/chunking/test_docling_chunker.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/unit/ai4rag/rag/chunking/test_docling_chunker.py b/tests/unit/ai4rag/rag/chunking/test_docling_chunker.py index ffbb3af..47ea827 100644 --- a/tests/unit/ai4rag/rag/chunking/test_docling_chunker.py +++ b/tests/unit/ai4rag/rag/chunking/test_docling_chunker.py @@ -78,20 +78,24 @@ def test_contextualize_true_includes_headings(self, doc_with_sections): heading_chunks = [c for c in chunks if c.metadata.get("headings")] assert len(heading_chunks) > 0 for chunk in heading_chunks: - assert chunk.metadata["headings"][0] in chunk.text + # headings is now a string like "Introduction > Section 1" + first_heading = chunk.metadata["headings"].split(" > ")[0] + assert first_heading in chunk.text def test_contextualize_false_excludes_headings(self, doc_with_sections): chunker = DoclingChunker(contextualize=False) chunks = chunker.split_documents([doc_with_sections]) heading_chunks = [c for c in chunks if c.metadata.get("headings")] for chunk in heading_chunks: - heading = chunk.metadata["headings"][0] - assert not chunk.text.startswith(heading) + # headings is now a string like "Introduction > Section 1" + first_heading = chunk.metadata["headings"].split(" > ")[0] + assert not chunk.text.startswith(first_heading) def test_headings_in_metadata(self, chunker, doc_with_sections): chunks = chunker.split_documents([doc_with_sections]) heading_chunks = [c for c in chunks if c.metadata.get("headings")] - heading_values = [c.metadata["headings"][0] for c in heading_chunks] + # headings is now a string like "Introduction" or "Introduction > Section 1" + heading_values = [c.metadata["headings"].split(" > ")[0] for c in heading_chunks] assert "Introduction" in heading_values assert "Methods" in heading_values assert "Results" in heading_values From e5e060c311755375f03be96befdfd45041fec85a Mon Sep 17 
00:00:00 2001 From: Lukasz Cmielowski Date: Wed, 1 Jul 2026 14:20:52 +0200 Subject: [PATCH 16/16] fix(assets): apply critical and recommended fixes from PR #77 review Applied 5 fixes from Mateusz's review comments: 1. Fix tool_choice schema (CRITICAL) - Changed from: {"mode": "required", "tools": [{}], "type": "file_search"} - Changed to: {"type": "file_search"} - Matches OGX SDK authoritative type definition 2. Update test assertion for tool_choice - Updated test to match corrected schema 3. Remove citation constant duplication (DRY) - Import _RAG_CITATION_INSTRUCTION from model_props.py - Remove duplicate HPO_CITATION_INSTRUCTION - Single source of truth for citation instruction 4. Document prefix constants - Add inline comments explaining usage of each constant group - Documents two-pass filtering architecture - Helps contributors understand which list to update 5. Fix normalize_answer_scaffold regex - Use regex to handle all comma orderings - Handles both ", with citations" and "with citations," correctly - No orphaned commas in output All 89 tests passing. 
Co-Authored-By: Claude Sonnet 4.5 Signed-off-by: Lukasz Cmielowski Assisted-by: Cursor --- .../assets_generator/pattern_builder.py | 2 +- .../assets_generator/prompt_filters.py | 21 +++++++++++++------ .../assets_generator/test_pattern_builder.py | 2 +- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/ai4rag/components/assets_generator/pattern_builder.py b/ai4rag/components/assets_generator/pattern_builder.py index 95c3c71..a0dab8b 100644 --- a/ai4rag/components/assets_generator/pattern_builder.py +++ b/ai4rag/components/assets_generator/pattern_builder.py @@ -271,7 +271,7 @@ def build_pattern_json( }, {"content": [{"text": _USER_QUERY_PLACEHOLDER, "type": "input_text"}], "role": "user"}, ], - "tool_choice": {"mode": "required", "tools": [{}], "type": "file_search"}, + "tool_choice": {"type": "file_search"}, "tools": [ { "type": "file_search", diff --git a/ai4rag/components/assets_generator/prompt_filters.py b/ai4rag/components/assets_generator/prompt_filters.py index 861cd79..efceb27 100644 --- a/ai4rag/components/assets_generator/prompt_filters.py +++ b/ai4rag/components/assets_generator/prompt_filters.py @@ -20,6 +20,8 @@ import re +from ai4rag.search_space.src.model_props import _RAG_CITATION_INSTRUCTION + # ============================================================================ # OGX Runtime Injection Strings # ============================================================================ @@ -40,16 +42,16 @@ "file citations", "document numbers for every factual claim", ) -HPO_CITATION_INSTRUCTION = ( - "You MUST cite sources using [1], [2], etc. matching the document numbers for every factual claim." 
-) +# HPO citation fragments for filtering (uses _RAG_CITATION_INSTRUCTION from model_props) HPO_CITATION_FRAGMENTS = ( - HPO_CITATION_INSTRUCTION, + _RAG_CITATION_INSTRUCTION, "You MUST cite sources using [1], [2], etc.", "You MUST cite sources using [1], [2].", ) # Grounding/retrieval-related phrases +# Used in: sentence-level filtering (is_sentence_ogx_duplicative) and +# system grounding detection (_system_has_grounding_policy in pattern_builder.py) GROUNDING_PREFIXES = ( "Answer ONLY using information from the documents", "Answer ONLY using information from documents retrieved", @@ -59,18 +61,21 @@ "If the retrieved documents do not contain", "If the documents do not contain", ) +# Used in: substring matching within sentences for partial phrase detection GROUNDING_SUBSTRINGS = ( "documents below", "retrieved via file search", "retrieved to help answer the user", "supporting information only in answering", ) +# Used in: whole-phrase removal from system prompts (strip_ogx_runtime_instructions) SYSTEM_GROUNDING_PHRASES = ( "Answer using ONLY the provided documents.", "Answer using ONLY information from documents retrieved via file search.", ) # File search tool markers +# Used in: detecting OGX tool result wrappers in sentence-level filtering FILE_SEARCH_MARKERS = ( "file_search tool found", "BEGIN of file_search tool results", @@ -80,12 +85,16 @@ "Do not add extra punctuation.
Use only the file IDs", ) -# User template duplicate detection (pass 1 filtering) +# User template duplicate detection +# Used in: Pass 2 filtering (_should_skip_user_export_line in pattern_builder.py) +# OGX-owned lines that must never be exported regardless of system prompt content USER_GROUNDING_SKIP_PREFIXES = ( "Answer ONLY using information from the documents below", "Do not use outside knowledge", "If the documents do not contain the answer", ) +# Used in: Pass 1 filtering (_should_skip_redundant_user_line in pattern_builder.py) +# Only suppressed when system prompt already has grounding policy to avoid duplication USER_RAG_GROUNDING_PREFIXES = ( "You are a specialized Retrieval Augmented Generation", "Prioritize correctness and ensure your response is grounded", @@ -209,7 +218,7 @@ def normalize_answer_scaffold(line: str) -> str: str Line with ", with citations" and "with citations" removed, whitespace normalized. """ - normalized = line.replace(", with citations", "").replace("with citations", "") + normalized = re.sub(r",?\s*with citations,?\s*", "", line) return collapse_whitespace(normalized) diff --git a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py index 6154052..e664ba7 100644 --- a/tests/unit/ai4rag/assets_generator/test_pattern_builder.py +++ b/tests/unit/ai4rag/assets_generator/test_pattern_builder.py @@ -109,7 +109,7 @@ def test_adds_responses_template(self): ] assert rt["max_output_tokens"] == 1024 assert rt["temperature"] == 0.7 - assert rt["tool_choice"] == {"mode": "required", "tools": [{}], "type": "file_search"} + assert rt["tool_choice"] == {"type": "file_search"} assert len(rt["tools"]) == 1 assert rt["tools"][0]["type"] == "file_search" assert "test_collection_001" in rt["tools"][0]["vector_store_ids"]