From 93398374656550b5e640fda7afde754a2a7e1fdf Mon Sep 17 00:00:00 2001
From: devatcadam
Date: Tue, 25 Nov 2025 19:43:22 +0100
Subject: [PATCH 1/3] asyncio for embedder.py and Update ollama.py for async
 request to embedder

# Description

Use `await asyncio.gather(*tasks)` in `TextChunkEmbedder.run`, embedding each chunk via
`async def _async_embed_chunk` under an `asyncio.Semaphore`, as in `LLMEntityRelationExtractor`.
Add `async def aembed_query` (overriding the base `Embedder.aembed_query` hook) and
`self.async_client = ollama.AsyncClient(**kwargs)` to `OllamaEmbeddings` so embedding
requests can be issued asynchronously.

> **Note**
>
> Please provide a description of the work completed in this PR below
>

## Type of Change

- [x] New feature
- [ ] Bug fix
- [ ] Breaking change
- [ ] Documentation update
- [ ] Project configuration change

## Complexity

> **Note**
>
> Please provide an estimated complexity of this PR of either Low, Medium or High
>

Complexity: Low

## How Has This Been Tested?

- [ ] Unit tests
- [ ] E2E tests
- [x] Manual tests

# Checklist

The following requirements should have been met (depending on the changes in the branch):

- [x] Documentation has been updated
- [ ] Unit tests have been updated
- [ ] E2E tests have been updated
- [x] Examples have been updated
- [ ] New files have copyright header
- [x] CLA (https://neo4j.com/developer/cla/) has been signed
- [ ] CHANGELOG.md updated if appropriate
---
 src/neo4j_graphrag/embeddings/ollama.py            | 29 +++++++++-
 .../experimental/components/embedder.py            | 57 +++++++++++++++++--
 2 files changed, 79 insertions(+), 7 deletions(-)

diff --git a/src/neo4j_graphrag/embeddings/ollama.py b/src/neo4j_graphrag/embeddings/ollama.py
index 88f850963..cfb43dd6b 100644
--- a/src/neo4j_graphrag/embeddings/ollama.py
+++ b/src/neo4j_graphrag/embeddings/ollama.py
@@ -19,7 +19,7 @@
 
 from neo4j_graphrag.embeddings.base import Embedder
 from neo4j_graphrag.exceptions import EmbeddingsGenerationError
-from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler
+from neo4j_graphrag.utils.rate_limit import RateLimitHandler, rate_limit_handler, async_rate_limit_handler
 
 
 class OllamaEmbeddings(Embedder):
@@ -47,6 +47,7 @@ def __init__(
         super().__init__(rate_limit_handler)
         self.model = model
         self.client = ollama.Client(**kwargs)
+        self.async_client = ollama.AsyncClient(**kwargs)
 
     @rate_limit_handler
     def embed_query(self, text: str, **kwargs: Any) -> list[float]:
@@ -73,3 +74,29 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]:
             raise EmbeddingsGenerationError("Embedding is not a list of floats.")
 
         return embedding
+
+    @async_rate_limit_handler
+    async def aembed_query(self, text: str, **kwargs: Any) -> list[float]:
+        """
+        Generate embeddings for a given query using an Ollama text embedding model.
+
+        Args:
+            text (str): The text to generate an embedding for.
+            **kwargs (Any): Additional keyword arguments to pass to the Ollama client.
+        """
+        embeddings_response = await self.async_client.embed(
+            model=self.model,
+            input=text,
+            **kwargs,
+        )
+
+        if embeddings_response is None or not embeddings_response.embeddings:
+            raise EmbeddingsGenerationError("Failed to retrieve embeddings.")
+
+        embeddings = embeddings_response.embeddings
+        # client always returns a sequence of sequences
+        embedding = embeddings[0]
+        if not isinstance(embedding, list):
+            raise EmbeddingsGenerationError("Embedding is not a list of floats.")
+
+        return embedding
\ No newline at end of file
diff --git a/src/neo4j_graphrag/experimental/components/embedder.py b/src/neo4j_graphrag/experimental/components/embedder.py
index f113ecde0..cdf24d766 100644
--- a/src/neo4j_graphrag/experimental/components/embedder.py
+++ b/src/neo4j_graphrag/experimental/components/embedder.py
@@ -14,6 +14,9 @@
 # limitations under the License.
 from pydantic import validate_call
 
+import asyncio
+from typing import Any, List, Optional, Union
+
 from neo4j_graphrag.embeddings.base import Embedder
 from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks
 from neo4j_graphrag.experimental.pipeline.component import Component
@@ -24,6 +27,7 @@ class TextChunkEmbedder(Component):
 
     Args:
         embedder (Embedder): The embedder to use to create the embeddings.
+        max_concurrency (int): The maximum number of concurrent embedding requests. Defaults to 5.
 
     Example:
 
@@ -34,14 +38,21 @@ class TextChunkEmbedder(Component):
     from neo4j_graphrag.experimental.pipeline import Pipeline
 
     embedder = OpenAIEmbeddings(model="text-embedding-3-large")
-    chunk_embedder = TextChunkEmbedder(embedder)
+    chunk_embedder = TextChunkEmbedder(embedder=embedder, max_concurrency=10)
     pipeline = Pipeline()
    pipeline.add_component(chunk_embedder, "chunk_embedder")
     """
 
-    def __init__(self, embedder: Embedder):
+    # ``embedder`` stays positional so existing ``TextChunkEmbedder(embedder)``
+    # callers keep working; ``max_concurrency`` bounds parallel requests.
+    def __init__(
+        self,
+        embedder: Embedder,
+        max_concurrency: int = 5,
+    ) -> None:
         self._embedder = embedder
+        self._max_concurrency = max_concurrency
 
     def _embed_chunk(self, text_chunk: TextChunk) -> TextChunk:
         """Embed a single text chunk.
 
@@ -62,9 +73,36 @@ def _embed_chunk(self, text_chunk: TextChunk) -> TextChunk:
             metadata=metadata,
             uid=text_chunk.uid,
         )
+
+    async def _async_embed_chunk(
+        self,
+        sem: asyncio.Semaphore,
+        text_chunk: TextChunk) -> TextChunk:
+        """Asynchronously embed a single text chunk.
+
+        Args:
+            text_chunk (TextChunk): The text chunk to embed.
+
+        Returns:
+            TextChunk: The text chunk with an added "embedding" key in its
+            metadata containing the embeddings of the text chunk's text.
+        """
+        async with sem:
+            embedding = await self._embedder.aembed_query(text_chunk.text)
+        metadata = text_chunk.metadata if text_chunk.metadata else {}
+        metadata["embedding"] = embedding
+        return TextChunk(
+            text=text_chunk.text,
+            index=text_chunk.index,
+            metadata=metadata,
+            uid=text_chunk.uid,
+        )
 
     @validate_call
-    async def run(self, text_chunks: TextChunks) -> TextChunks:
+    async def run(
+        self,
+        text_chunks: TextChunks
+    ) -> TextChunks:
         """Embed a list of text chunks.
 
         Args:
@@ -73,6 +111,13 @@ async def run(self, text_chunks: TextChunks) -> TextChunks:
         Returns:
             TextChunks: The input text chunks with each one having an added embedding.
         """
-        return TextChunks(
-            chunks=[self._embed_chunk(text_chunk) for text_chunk in text_chunks.chunks]
-        )
+        sem = asyncio.Semaphore(self._max_concurrency)
+        tasks = [
+            self._async_embed_chunk(
+                sem,
+                text_chunk,
+            )
+            for text_chunk in text_chunks.chunks
+        ]
+        embedded_chunks = list(await asyncio.gather(*tasks))
+        return TextChunks(chunks=embedded_chunks)
\ No newline at end of file

From ccfb9819b9de62b712269cf3bf9b023ac7545628 Mon Sep 17 00:00:00 2001
From: devatcadam
Date: Mon, 8 Dec 2025 13:11:53 +0000
Subject: [PATCH 2/3] As suggested by @stellasia : " I'd suggest to add a
 default implementation that calls the non async method, so that we do not
 have to work on the actual async implementation for all embedders right now,
 we will be able to add them progressively."

---
 src/neo4j_graphrag/embeddings/base.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/src/neo4j_graphrag/embeddings/base.py b/src/neo4j_graphrag/embeddings/base.py
index 34c0c7b59..c0cf8d077 100644
--- a/src/neo4j_graphrag/embeddings/base.py
+++ b/src/neo4j_graphrag/embeddings/base.py
@@ -48,3 +48,16 @@ def embed_query(self, text: str) -> list[float]:
         Returns:
             list[float]: A vector embedding.
         """
+
+    def aembed_query(self, text: str) -> list[float]:
+        """Asynchronously embed query text.
+        Call embed_query by default as suggested by @stellasia.
+        Implementations for all embedders will be added progressively.
+
+        Args:
+            text (str): Text to convert to vector embedding
+
+        Returns:
+            list[float]: A vector embedding.
+        """
+        return self.embed_query(text)
\ No newline at end of file

From 1c408f0db278a6b455fac2bf8f5008cf83eed5f1 Mon Sep 17 00:00:00 2001
From: devatcadam
Date: Sat, 13 Dec 2025 15:29:23 +0000
Subject: [PATCH 3/3] async def

---
 src/neo4j_graphrag/embeddings/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/neo4j_graphrag/embeddings/base.py b/src/neo4j_graphrag/embeddings/base.py
index c0cf8d077..550bc28c9 100644
--- a/src/neo4j_graphrag/embeddings/base.py
+++ b/src/neo4j_graphrag/embeddings/base.py
@@ -49,7 +49,7 @@ def embed_query(self, text: str) -> list[float]:
             list[float]: A vector embedding.
         """
 
-    def aembed_query(self, text: str) -> list[float]:
+    async def aembed_query(self, text: str) -> list[float]:
         """Asynchronously embed query text.
         Call embed_query by default as suggested by @stellasia.
         Implementations for all embedders will be added progressively.