diff --git a/docs/examples_notebooks/api_overview.ipynb b/docs/examples_notebooks/api_overview.ipynb index 2a0c0f15de..abcd7832fc 100644 --- a/docs/examples_notebooks/api_overview.ipynb +++ b/docs/examples_notebooks/api_overview.ipynb @@ -28,10 +28,11 @@ "from pathlib import Path\n", "from pprint import pprint\n", "\n", - "import graphrag.api as api\n", "import pandas as pd\n", "from graphrag.config.load_config import load_config\n", - "from graphrag.index.typing.pipeline_run_result import PipelineRunResult" + "from graphrag.index.typing.pipeline_run_result import PipelineRunResult\n", + "\n", + "import graphrag.api as api" ] }, { diff --git a/docs/examples_notebooks/input_documents.ipynb b/docs/examples_notebooks/input_documents.ipynb index 5657770eaf..505c0fe1f3 100644 --- a/docs/examples_notebooks/input_documents.ipynb +++ b/docs/examples_notebooks/input_documents.ipynb @@ -30,10 +30,11 @@ "from pathlib import Path\n", "from pprint import pprint\n", "\n", - "import graphrag.api as api\n", "import pandas as pd\n", "from graphrag.config.load_config import load_config\n", - "from graphrag.index.typing.pipeline_run_result import PipelineRunResult" + "from graphrag.index.typing.pipeline_run_result import PipelineRunResult\n", + "\n", + "import graphrag.api as api" ] }, { diff --git a/packages/graphrag/graphrag/query/context_builder/dynamic_community_selection.py b/packages/graphrag/graphrag/query/context_builder/dynamic_community_selection.py index 904478f738..3981e5c8e5 100644 --- a/packages/graphrag/graphrag/query/context_builder/dynamic_community_selection.py +++ b/packages/graphrag/graphrag/query/context_builder/dynamic_community_selection.py @@ -123,8 +123,10 @@ async def select(self, query: str) -> tuple[list[CommunityReport], dict[str, Any # TODO check why some sub_communities are NOT in report_df if community in self.communities: for child in self.communities[community].children: - if child in self.reports: - communities_to_rate.append(child) + # Convert child to string to match self.reports key type + child_str = str(child) + if child_str in self.reports: + communities_to_rate.append(child_str) else: logger.debug( "dynamic community selection: cannot find community %s in reports", diff --git a/tests/unit/query/context_builder/dynamic_community_selection.py b/tests/unit/query/context_builder/dynamic_community_selection.py new file mode 100644 index 0000000000..ba63f0c774 --- /dev/null +++ b/tests/unit/query/context_builder/dynamic_community_selection.py @@ -0,0 +1,205 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Tests for dynamic community selection with type handling.""" + +from unittest.mock import MagicMock + +from graphrag.data_model.community import Community +from graphrag.data_model.community_report import CommunityReport +from graphrag.query.context_builder.dynamic_community_selection import ( + DynamicCommunitySelection, +) + + +def create_mock_tokenizer() -> MagicMock: + """Create a mock tokenizer.""" + tokenizer = MagicMock() + tokenizer.encode.return_value = [1, 2, 3] + return tokenizer + + +def create_mock_model() -> MagicMock: + """Create a mock chat model.""" + return MagicMock() + + +def test_dynamic_community_selection_handles_int_children(): + """Test that DynamicCommunitySelection correctly handles children IDs as integers. + + This tests the fix for issue #2004 where children IDs could be integers + while self.reports keys are strings, causing child communities to be skipped. + """ + # Create communities with integer children (simulating the bug scenario) + # Note: Even though the type annotation says list[str], actual data may have ints + communities = [ + Community( + id="comm-0", + short_id="0", + title="Root Community", + level="0", + parent="", + children=[1, 2], # type: ignore[list-item] # Integer children - testing bug fix + ), + Community( + id="comm-1", + short_id="1", + title="Child Community 1", + level="1", + parent="0", + children=[], + ), + Community( + id="comm-2", + short_id="2", + title="Child Community 2", + level="1", + parent="0", + children=[], + ), + ] + + # Create community reports with string community_id + reports = [ + CommunityReport( + id="report-0", + short_id="0", + title="Report 0", + community_id="0", + summary="Root community summary", + full_content="Root community full content", + rank=1.0, + ), + CommunityReport( + id="report-1", + short_id="1", + title="Report 1", + community_id="1", + summary="Child 1 summary", + full_content="Child 1 full content", + rank=1.0, + ), + CommunityReport( + id="report-2", + short_id="2", + title="Report 2", + community_id="2", + summary="Child 2 summary", + full_content="Child 2 full content", + rank=1.0, + ), + ] + + model = create_mock_model() + tokenizer = create_mock_tokenizer() + + selector = DynamicCommunitySelection( + community_reports=reports, + communities=communities, + model=model, + tokenizer=tokenizer, + threshold=1, + keep_parent=False, + max_level=2, + ) + + # Verify that reports are keyed by string + assert "0" in selector.reports + assert "1" in selector.reports + assert "2" in selector.reports + + # Verify that communities are keyed by string short_id + assert "0" in selector.communities + assert "1" in selector.communities + assert "2" in selector.communities + + # Verify that the children are properly accessible + # Before the fix, int children would fail the `in self.reports` check + root_community = selector.communities["0"] + for child in root_community.children: + child_id = str(child) + # This should now work with the fix + assert child_id in selector.reports, ( + f"Child {child} (as '{child_id}') should be found in reports" + ) + + +def test_dynamic_community_selection_handles_str_children(): + """Test that DynamicCommunitySelection works correctly with string children IDs.""" + communities = [ + Community( + id="comm-0", + short_id="0", + title="Root Community", + level="0", + parent="", + children=["1", "2"], # String children - expected type + ), + Community( + id="comm-1", + short_id="1", + title="Child Community 1", + level="1", + parent="0", + children=[], + ), + Community( + id="comm-2", + short_id="2", + title="Child Community 2", + level="1", + parent="0", + children=[], + ), + ] + + reports = [ + CommunityReport( + id="report-0", + short_id="0", + title="Report 0", + community_id="0", + summary="Root community summary", + full_content="Root community full content", + rank=1.0, + ), + CommunityReport( + id="report-1", + short_id="1", + title="Report 1", + community_id="1", + summary="Child 1 summary", + full_content="Child 1 full content", + rank=1.0, + ), + CommunityReport( + id="report-2", + short_id="2", + title="Report 2", + community_id="2", + summary="Child 2 summary", + full_content="Child 2 full content", + rank=1.0, + ), + ] + + model = create_mock_model() + tokenizer = create_mock_tokenizer() + + selector = DynamicCommunitySelection( + community_reports=reports, + communities=communities, + model=model, + tokenizer=tokenizer, + threshold=1, + keep_parent=False, + max_level=2, + ) + + # Verify that children can be found in reports + root_community = selector.communities["0"] + for child in root_community.children: + child_id = str(child) + assert child_id in selector.reports, ( + f"Child {child} (as '{child_id}') should be found in reports" + ) diff --git a/tests/verbs/test_create_community_reports.py b/tests/verbs/test_create_community_reports.py index d479120ce2..561f54108b 100644 --- a/tests/verbs/test_create_community_reports.py +++ b/tests/verbs/test_create_community_reports.py @@ -4,15 +4,16 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.data_model.schemas import COMMUNITY_REPORTS_FINAL_COLUMNS -from graphrag.index.operations.summarize_communities.community_reports_extractor import ( - CommunityReportResponse, - FindingModel, -) from graphrag.index.workflows.create_community_reports import ( run_workflow, ) from graphrag.utils.storage import load_table_from_storage +from graphrag.index.operations.summarize_communities.community_reports_extractor import ( + CommunityReportResponse, + FindingModel, +) + from .util import ( DEFAULT_MODEL_CONFIG, compare_outputs, diff --git a/unified-search-app/app/app_logic.py b/unified-search-app/app/app_logic.py index dc64e0e77c..a573b9daa5 100644 --- a/unified-search-app/app/app_logic.py +++ b/unified-search-app/app/app_logic.py @@ -7,7 +7,6 @@ import logging from typing import TYPE_CHECKING -import graphrag.api as api import streamlit as st from knowledge_loader.data_sources.loader import ( create_datasource, @@ -18,6 +17,8 @@ from state.session_variables import SessionVariables from ui.search import display_search_result +import graphrag.api as api + if TYPE_CHECKING: import pandas as pd