Skip to content

Commit 59ab206

Browse files
committed Sep 19, 2024
feat - add support to return & save the merged graph [useful for investigating all descriptions of the nodes & of the edges between nodes]
1 parent 3125e0a commit 59ab206

File tree

5 files changed

+34
-20
lines changed

5 files changed

+34
-20
lines changed
 

‎examples/simple-app/app/common.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,13 @@ def save_artifacts(artifacts: IndexerArtifacts, path: Path):
193193
artifacts.text_units.to_parquet(f"{path}/text_units.parquet")
194194
artifacts.communities_reports.to_parquet(f"{path}/communities_reports.parquet")
195195

196-
if artifacts.graph is not None:
197-
with path.joinpath("graph.pickle").open("wb") as fp:
198-
pickle.dump(artifacts.graph, fp)
196+
if artifacts.merged_graph is not None:
197+
with path.joinpath("merged-graph.pickle").open("wb") as fp:
198+
pickle.dump(artifacts.merged_graph, fp)
199+
200+
if artifacts.summarized_graph is not None:
201+
with path.joinpath("summarized-graph.pickle").open("wb") as fp:
202+
pickle.dump(artifacts.summarized_graph, fp)
199203

200204
if artifacts.communities is not None:
201205
with path.joinpath("community_info.pickle").open("wb") as fp:
@@ -208,13 +212,19 @@ def load_artifacts(path: Path) -> IndexerArtifacts:
208212
text_units = pd.read_parquet(f"{path}/text_units.parquet")
209213
communities_reports = pd.read_parquet(f"{path}/communities_reports.parquet")
210214

211-
graph = None
215+
merged_graph = None
216+
summarized_graph = None
212217
communities = None
213218

214-
graph_pickled = path.joinpath("graph.pickle")
215-
if graph_pickled.exists():
216-
with graph_pickled.open("rb") as fp:
217-
graph = pickle.load(fp) # noqa: S301
219+
merged_graph_pickled = path.joinpath("merged-graph.pickle")
220+
if merged_graph_pickled.exists():
221+
with merged_graph_pickled.open("rb") as fp:
222+
merged_graph = pickle.load(fp) # noqa: S301
223+
224+
summarized_graph_pickled = path.joinpath("summarized-graph.pickle")
225+
if summarized_graph_pickled.exists():
226+
with summarized_graph_pickled.open("rb") as fp:
227+
summarized_graph = pickle.load(fp) # noqa: S301
218228

219229
community_info_pickled = path.joinpath("community_info.pickle")
220230
if community_info_pickled.exists():
@@ -226,7 +236,8 @@ def load_artifacts(path: Path) -> IndexerArtifacts:
226236
relationships,
227237
text_units,
228238
communities_reports,
229-
graph=graph,
239+
merged_graph=merged_graph,
240+
summarized_graph=summarized_graph,
230241
communities=communities,
231242
)
232243

‎pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "langchain-graphrag"
3-
version = "0.0.4"
3+
version = "0.0.5"
44
description = "Implementation of GraphRAG (https://arxiv.org/pdf/2404.16130)"
55
authors = [{ name = "Kapil Sachdeva", email = "notan@email.com" }]
66
dependencies = [

‎src/langchain_graphrag/indexing/artifacts.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ class IndexerArtifacts(NamedTuple):
1212
relationships: pd.DataFrame
1313
text_units: pd.DataFrame
1414
communities_reports: pd.DataFrame
15-
graph: nx.Graph | None = None
15+
merged_graph: nx.Graph | None = None
16+
summarized_graph: nx.Graph | None = None
1617
communities: CommunityDetectionResult | None = None
1718

1819
def _entity_info(self, top_k: int) -> None:

‎src/langchain_graphrag/indexing/graph_generation/generator.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ def __init__(
1717
self._graphs_merger = graphs_merger
1818
self._er_description_summarizer = er_description_summarizer
1919

20-
def run(self, text_units: pd.DataFrame) -> nx.Graph:
20+
def run(self, text_units: pd.DataFrame) -> tuple[nx.Graph, nx.Graph]:
2121
er_graphs = self._er_extractor.invoke(text_units)
2222
er_merged_graph = self._graphs_merger(er_graphs)
23-
return self._er_description_summarizer.invoke(er_merged_graph)
23+
er_summarized_graph = self._er_description_summarizer.invoke(er_merged_graph)
24+
return er_merged_graph, er_summarized_graph

‎src/langchain_graphrag/indexing/simple_indexer.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -45,26 +45,26 @@ def run(self, documents: list[Document]) -> IndexerArtifacts:
4545
# Step 1 - Text Unit extraction
4646
df_base_text_units = self._text_unit_extractor.run(documents)
4747

48-
# Step 2 - Generate graph
49-
graph = self._graph_generator.run(df_base_text_units)
48+
# Step 2 - Generate graphs
49+
merged_graph, summarized_graph = self._graph_generator.run(df_base_text_units)
5050

5151
# Step 3 - Detect communities in Graph
52-
community_detection_result = self._community_detector.run(graph)
52+
community_detection_result = self._community_detector.run(summarized_graph)
5353

5454
# Step 4 - Reports for detected Communities (depends on Step 2 & Step 3)
5555
df_communities_reports = self._communities_report_artifacts_generator.run(
5656
community_detection_result,
57-
graph,
57+
summarized_graph,
5858
)
5959

6060
# Step 5 - Entities generation (depends on Step 2 & Step 3)
6161
df_entities = self._entities_artifacts_generator.run(
6262
community_detection_result,
63-
graph,
63+
summarized_graph,
6464
)
6565

6666
# Step 6 - Relationships generation (depends on Step 2)
67-
df_relationships = self._relationships_artifacts_generator.run(graph)
67+
df_relationships = self._relationships_artifacts_generator.run(summarized_graph)
6868

6969
# Step 7 - Text Units generation (depends on Steps 1, 5, 6)
7070
df_text_units = self._text_units_artifacts_generator.run(
@@ -78,6 +78,7 @@ def run(self, documents: list[Document]) -> IndexerArtifacts:
7878
relationships=df_relationships,
7979
text_units=df_text_units,
8080
communities_reports=df_communities_reports,
81-
graph=graph,
81+
summarized_graph=summarized_graph,
82+
merged_graph=merged_graph,
8283
communities=community_detection_result,
8384
)

0 commit comments

Comments (0)
Please sign in to comment.