From efcb69a61916521b7ee37709df7e4b82dd8c2e5d Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Sat, 20 Sep 2025 19:58:31 +0200 Subject: [PATCH 01/14] Improve error handling for embedders including rate limit --- src/neo4j_graphrag/embeddings/base.py | 15 +++++++ src/neo4j_graphrag/embeddings/cohere.py | 30 ++++++++++---- src/neo4j_graphrag/embeddings/mistral.py | 12 +++++- src/neo4j_graphrag/embeddings/ollama.py | 12 +++++- src/neo4j_graphrag/embeddings/openai.py | 26 +++++++++--- .../embeddings/sentence_transformers.py | 40 +++++++++++++------ src/neo4j_graphrag/embeddings/vertexai.py | 29 ++++++++++---- src/neo4j_graphrag/llm/rate_limit.py | 6 +-- 8 files changed, 131 insertions(+), 39 deletions(-) diff --git a/src/neo4j_graphrag/embeddings/base.py b/src/neo4j_graphrag/embeddings/base.py index 73d6bc9ac..3f131f66a 100644 --- a/src/neo4j_graphrag/embeddings/base.py +++ b/src/neo4j_graphrag/embeddings/base.py @@ -15,14 +15,29 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import Any, Optional + +from neo4j_graphrag.llm.rate_limit import ( + DEFAULT_RATE_LIMIT_HANDLER, + RateLimitHandler, +) class Embedder(ABC): """ Interface for embedding models. An embedder passed into a retriever must implement this interface. + + Args: + rate_limit_handler (Optional[RateLimitHandler]): Handler for rate limiting. Defaults to retry with exponential backoff. """ + def __init__(self, rate_limit_handler: Optional[RateLimitHandler] = None): + if rate_limit_handler is not None: + self._rate_limit_handler = rate_limit_handler + else: + self._rate_limit_handler = DEFAULT_RATE_LIMIT_HANDLER + @abstractmethod def embed_query(self, text: str) -> list[float]: """Embed query text. diff --git a/src/neo4j_graphrag/embeddings/cohere.py b/src/neo4j_graphrag/embeddings/cohere.py index 63906a5e0..d8df69036 100644 --- a/src/neo4j_graphrag/embeddings/cohere.py +++ b/src/neo4j_graphrag/embeddings/cohere.py @@ -14,9 +14,11 @@ # limitations under the License. from __future__ import annotations -from typing import Any +from typing import Any, Optional from neo4j_graphrag.embeddings.base import Embedder +from neo4j_graphrag.exceptions import EmbeddingsGenerationError +from neo4j_graphrag.llm.rate_limit import RateLimitHandler, rate_limit_handler try: import cohere @@ -25,19 +27,31 @@ class CohereEmbeddings(Embedder): - def __init__(self, model: str = "", **kwargs: Any) -> None: + def __init__( + self, + model: str = "", + rate_limit_handler: Optional[RateLimitHandler] = None, + **kwargs: Any, + ) -> None: if cohere is None: raise ImportError( """Could not import cohere python client. Please install it with `pip install "neo4j-graphrag[cohere]"`.""" ) + super().__init__(rate_limit_handler) self.model = model self.client = cohere.Client(**kwargs) + @rate_limit_handler def embed_query(self, text: str, **kwargs: Any) -> list[float]: - response = self.client.embed( - texts=[text], - model=self.model, - **kwargs, - ) - return response.embeddings[0] # type: ignore + try: + response = self.client.embed( + texts=[text], + model=self.model, + **kwargs, + ) + return response.embeddings[0] # type: ignore + except Exception as e: + raise EmbeddingsGenerationError( + f"Failed to generate embedding with Cohere: {e}" + ) from e diff --git a/src/neo4j_graphrag/embeddings/mistral.py b/src/neo4j_graphrag/embeddings/mistral.py index 099430193..b356f3ce0 100644 --- a/src/neo4j_graphrag/embeddings/mistral.py +++ b/src/neo4j_graphrag/embeddings/mistral.py @@ -16,10 +16,11 @@ from __future__ import annotations import os -from typing import Any +from typing import Any, Optional from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError +from neo4j_graphrag.llm.rate_limit import RateLimitHandler, rate_limit_handler try: from mistralai import Mistral @@ -36,18 +37,25 @@ class MistralAIEmbeddings(Embedder): model (str): The name of the Mistral AI text embedding model to use. Defaults to "mistral-embed". """ - def __init__(self, model: str = "mistral-embed", **kwargs: Any) -> None: + def __init__( + self, + model: str = "mistral-embed", + rate_limit_handler: Optional[RateLimitHandler] = None, + **kwargs: Any, + ) -> None: if Mistral is None: raise ImportError( """Could not import mistralai. Please install it with `pip install "neo4j-graphrag[mistralai]"`.""" ) + super().__init__(rate_limit_handler) api_key = kwargs.pop("api_key", None) if api_key is None: api_key = os.getenv("MISTRAL_API_KEY", "") self.model = model self.mistral_client = Mistral(api_key=api_key, **kwargs) + @rate_limit_handler def embed_query(self, text: str, **kwargs: Any) -> list[float]: """ Generate embeddings for a given query using a Mistral AI text embedding model. diff --git a/src/neo4j_graphrag/embeddings/ollama.py b/src/neo4j_graphrag/embeddings/ollama.py index 78775ba60..5b818ee2d 100644 --- a/src/neo4j_graphrag/embeddings/ollama.py +++ b/src/neo4j_graphrag/embeddings/ollama.py @@ -15,10 +15,11 @@ from __future__ import annotations -from typing import Any +from typing import Any, Optional from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError +from neo4j_graphrag.llm.rate_limit import RateLimitHandler, rate_limit_handler class OllamaEmbeddings(Embedder): @@ -30,7 +31,12 @@ class OllamaEmbeddings(Embedder): model (str): The name of the Mistral AI text embedding model to use. Defaults to "mistral-embed". """ - def __init__(self, model: str, **kwargs: Any) -> None: + def __init__( + self, + model: str, + rate_limit_handler: Optional[RateLimitHandler] = None, + **kwargs: Any, + ) -> None: try: import ollama except ImportError: @@ -38,9 +44,11 @@ def __init__(self, model: str, **kwargs: Any) -> None: """Could not import ollama python client. Please install it with `pip install "neo4j_graphrag[ollama]"`.""" ) + super().__init__(rate_limit_handler) self.model = model self.client = ollama.Client(**kwargs) + @rate_limit_handler def embed_query(self, text: str, **kwargs: Any) -> list[float]: """ Generate embeddings for a given query using an Ollama text embedding model. diff --git a/src/neo4j_graphrag/embeddings/openai.py b/src/neo4j_graphrag/embeddings/openai.py index 4a4d60387..880d8d99e 100644 --- a/src/neo4j_graphrag/embeddings/openai.py +++ b/src/neo4j_graphrag/embeddings/openai.py @@ -16,9 +16,11 @@ from __future__ import annotations import abc -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional from neo4j_graphrag.embeddings.base import Embedder +from neo4j_graphrag.exceptions import EmbeddingsGenerationError +from neo4j_graphrag.llm.rate_limit import RateLimitHandler, rate_limit_handler if TYPE_CHECKING: import openai @@ -31,7 +33,12 @@ class BaseOpenAIEmbeddings(Embedder, abc.ABC): client: openai.OpenAI - def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None: + def __init__( + self, + model: str = "text-embedding-ada-002", + rate_limit_handler: Optional[RateLimitHandler] = None, + **kwargs: Any, + ) -> None: try: import openai except ImportError: @@ -39,6 +46,7 @@ def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None """Could not import openai python client. Please install it with `pip install "neo4j-graphrag[openai]"`.""" ) + super().__init__(rate_limit_handler) self.openai = openai self.model = model self.client = self._initialize_client(**kwargs) @@ -51,6 +59,7 @@ def _initialize_client(self, **kwargs: Any) -> Any: """ pass + @rate_limit_handler def embed_query(self, text: str, **kwargs: Any) -> list[float]: """ Generate embeddings for a given query using an OpenAI text embedding model. @@ -59,9 +68,16 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]: text (str): The text to generate an embedding for. **kwargs (Any): Additional arguments to pass to the OpenAI embedding generation function. """ - response = self.client.embeddings.create(input=text, model=self.model, **kwargs) - embedding: list[float] = response.data[0].embedding - return embedding + try: + response = self.client.embeddings.create( + input=text, model=self.model, **kwargs + ) + embedding: list[float] = response.data[0].embedding + return embedding + except Exception as e: + raise EmbeddingsGenerationError( + f"Failed to generate embedding with OpenAI: {e}" + ) from e class OpenAIEmbeddings(BaseOpenAIEmbeddings): diff --git a/src/neo4j_graphrag/embeddings/sentence_transformers.py b/src/neo4j_graphrag/embeddings/sentence_transformers.py index f30ed5b7f..b8fad4259 100644 --- a/src/neo4j_graphrag/embeddings/sentence_transformers.py +++ b/src/neo4j_graphrag/embeddings/sentence_transformers.py @@ -13,14 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional from neo4j_graphrag.embeddings.base import Embedder +from neo4j_graphrag.exceptions import EmbeddingsGenerationError +from neo4j_graphrag.llm.rate_limit import RateLimitHandler, rate_limit_handler class SentenceTransformerEmbeddings(Embedder): def __init__( - self, model: str = "all-MiniLM-L6-v2", *args: Any, **kwargs: Any + self, + model: str = "all-MiniLM-L6-v2", + rate_limit_handler: Optional[RateLimitHandler] = None, + *args: Any, + **kwargs: Any, ) -> None: try: import numpy as np @@ -31,17 +37,27 @@ def __init__( """Could not import sentence_transformers python package. Please install it with `pip install "neo4j-graphrag[sentence-transformers]"`.""" ) + super().__init__(rate_limit_handler) self.torch = torch self.np = np self.model = sentence_transformers.SentenceTransformer(model, *args, **kwargs) - def embed_query(self, text: str) -> Any: - result = self.model.encode([text]) - if isinstance(result, self.torch.Tensor) or isinstance(result, self.np.ndarray): - return result.flatten().tolist() - elif isinstance(result, list) and all( - isinstance(x, self.torch.Tensor) for x in result - ): - return [item for tensor in result for item in tensor.flatten().tolist()] - else: - raise ValueError("Unexpected return type from model encoding") + @rate_limit_handler + def embed_query(self, text: str) -> list[float]: + try: + result = self.model.encode([text]) + + if isinstance(result, self.torch.Tensor) or isinstance( + result, self.np.ndarray + ): + return result.flatten().tolist() + elif isinstance(result, list) and all( + isinstance(x, self.torch.Tensor) for x in result + ): + return [item for tensor in result for item in tensor.flatten().tolist()] + else: + raise ValueError("Unexpected return type from model encoding") + except Exception as e: + raise EmbeddingsGenerationError( + "Failed to generate embedding with SentenceTransformer" + ) from e diff --git a/src/neo4j_graphrag/embeddings/vertexai.py b/src/neo4j_graphrag/embeddings/vertexai.py index cfed3868a..20c4c76b7 100644 --- a/src/neo4j_graphrag/embeddings/vertexai.py +++ b/src/neo4j_graphrag/embeddings/vertexai.py @@ -14,14 +14,16 @@ # limitations under the License. from __future__ import annotations -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional from neo4j_graphrag.embeddings.base import Embedder +from neo4j_graphrag.exceptions import EmbeddingsGenerationError +from neo4j_graphrag.llm.rate_limit import RateLimitHandler, rate_limit_handler try: from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel except (ImportError, AttributeError): - TextEmbeddingModel = TextEmbeddingInput = None # type: ignore[misc, assignment] + TextEmbeddingModel = TextEmbeddingInput = None if TYPE_CHECKING: @@ -37,14 +39,20 @@ class VertexAIEmbeddings(Embedder): model (str): The name of the Vertex AI text embedding model to use. Defaults to "text-embedding-004". """ - def __init__(self, model: str = "text-embedding-004") -> None: + def __init__( + self, + model: str = "text-embedding-004", + rate_limit_handler: Optional[RateLimitHandler] = None, + ) -> None: if TextEmbeddingModel is None: raise ImportError( """Could not import Vertex AI Python client. Please install it with `pip install "neo4j-graphrag[google]"`.""" ) + super().__init__(rate_limit_handler) self.model = TextEmbeddingModel.from_pretrained(model) + @rate_limit_handler def embed_query( self, text: str, task_type: str = "RETRIEVAL_QUERY", **kwargs: Any ) -> list[float]: @@ -56,7 +64,14 @@ def embed_query( task_type (str): The type of the text embedding task. Defaults to "RETRIEVAL_QUERY". See https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#tasktype for a full list. **kwargs (Any): Additional keyword arguments to pass to the Vertex AI client's get_embeddings method. """ - # type annotation needed for mypy - inputs: list[str | TextEmbeddingInput] = [TextEmbeddingInput(text, task_type)] - embeddings = self.model.get_embeddings(inputs, **kwargs) - return embeddings[0].values + try: + # type annotation needed for mypy + inputs: list[str | TextEmbeddingInput] = [ + TextEmbeddingInput(text, task_type) + ] + embeddings = self.model.get_embeddings(inputs, **kwargs) + return list(embeddings[0].values) + except Exception as e: + raise EmbeddingsGenerationError( + f"Failed to generate embedding with VertexAI: {e}" + ) from e diff --git a/src/neo4j_graphrag/llm/rate_limit.py b/src/neo4j_graphrag/llm/rate_limit.py index 098597f78..165de21ee 100644 --- a/src/neo4j_graphrag/llm/rate_limit.py +++ b/src/neo4j_graphrag/llm/rate_limit.py @@ -147,7 +147,7 @@ def handle_async(self, func: AF) -> AF: def is_rate_limit_error(exception: Exception) -> bool: - """Check if an exception is a rate limit error from any LLM provider. + """Check if an exception is a rate limit error from any LLM provider or embedder. Args: exception: The exception to check. @@ -158,8 +158,8 @@ def is_rate_limit_error(exception: Exception) -> bool: error_type = type(exception).__name__.lower() exception_str = str(exception).lower() - # For LLMGenerationError (which wraps all provider errors), check provider-specific patterns - if error_type == "llmgenerationerror": + # For LLMGenerationError or EmbeddingsGenerationError (which wrap all provider errors), check provider-specific patterns + if error_type in ["llmgenerationerror", "embeddingsgenerationerror"]: # Check for various rate limit patterns from different providers rate_limit_patterns = [ "error code: 429", # Azure OpenAI From 8e2190ffaaacb469b7ff4b08a9683515ef0654c5 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Sat, 20 Sep 2025 19:58:54 +0200 Subject: [PATCH 02/14] Update unit tests --- tests/unit/embeddings/test_cohere_embedder.py | 11 +++++++++++ tests/unit/embeddings/test_openai_embedder.py | 16 ++++++++++++++++ .../embeddings/test_sentence_transformers.py | 16 ++++++++++++++++ tests/unit/embeddings/test_vertexai_embedder.py | 13 +++++++++++++ 4 files changed, 56 insertions(+) diff --git a/tests/unit/embeddings/test_cohere_embedder.py b/tests/unit/embeddings/test_cohere_embedder.py index 244e90c52..c962d7450 100644 --- a/tests/unit/embeddings/test_cohere_embedder.py +++ b/tests/unit/embeddings/test_cohere_embedder.py @@ -16,6 +16,7 @@ import pytest from neo4j_graphrag.embeddings.cohere import CohereEmbeddings +from neo4j_graphrag.exceptions import EmbeddingsGenerationError @patch("neo4j_graphrag.embeddings.cohere.cohere", None) @@ -32,3 +33,13 @@ def test_cohere_embedder_happy_path(mock_cohere: Mock) -> None: embedder = CohereEmbeddings() res = embedder.embed_query("my text") assert res == [1.0, 2.0] + + +@patch("neo4j_graphrag.embeddings.cohere.cohere") +def test_cohere_embedder_error_handling(mock_cohere: Mock) -> None: + mock_cohere.Client.return_value.embed.side_effect = Exception("API Error") + embedder = CohereEmbeddings() + with pytest.raises( + EmbeddingsGenerationError, match="Failed to generate embedding with Cohere" + ): + embedder.embed_query("my text") diff --git a/tests/unit/embeddings/test_openai_embedder.py b/tests/unit/embeddings/test_openai_embedder.py index a1b940f04..046a9519f 100644 --- a/tests/unit/embeddings/test_openai_embedder.py +++ b/tests/unit/embeddings/test_openai_embedder.py @@ -20,6 +20,7 @@ AzureOpenAIEmbeddings, OpenAIEmbeddings, ) +from neo4j_graphrag.exceptions import EmbeddingsGenerationError def get_mock_openai() -> MagicMock: @@ -92,3 +93,18 @@ def test_azure_openai_embedder_does_not_call_openai_client() -> None: api_key="my_key", api_version="2023-05-15", ) + + +@patch("builtins.__import__") +def test_openai_embedder_error_handling(mock_import: Mock) -> None: + mock_openai = get_mock_openai() + mock_import.return_value = mock_openai + + mock_openai.OpenAI.return_value.embeddings.create.side_effect = Exception( + "API Error" + ) + embedder = OpenAIEmbeddings(api_key="my key") + with pytest.raises( + EmbeddingsGenerationError, match="Failed to generate embedding with OpenAI" + ): + embedder.embed_query("my text") diff --git a/tests/unit/embeddings/test_sentence_transformers.py b/tests/unit/embeddings/test_sentence_transformers.py index 197095e49..7a8db0817 100644 --- a/tests/unit/embeddings/test_sentence_transformers.py +++ b/tests/unit/embeddings/test_sentence_transformers.py @@ -7,6 +7,7 @@ from neo4j_graphrag.embeddings.sentence_transformers import ( SentenceTransformerEmbeddings, ) +from neo4j_graphrag.exceptions import EmbeddingsGenerationError def get_mock_sentence_transformers() -> MagicMock: @@ -55,3 +56,18 @@ def test_embed_query(mock_import: Mock) -> None: def test_import_error(mock_import: Mock) -> None: with pytest.raises(ImportError): SentenceTransformerEmbeddings() + + +@patch("builtins.__import__") +def test_embed_query_error_handling(mock_import: Mock) -> None: + MockSentenceTransformer = get_mock_sentence_transformers() + mock_import.return_value = MockSentenceTransformer + mock_model = MockSentenceTransformer.SentenceTransformer.return_value + mock_model.encode.side_effect = Exception("Model error") + + instance = SentenceTransformerEmbeddings() + with pytest.raises( + EmbeddingsGenerationError, + match="Failed to generate embedding with SentenceTransformer", + ): + instance.embed_query("test query") diff --git a/tests/unit/embeddings/test_vertexai_embedder.py b/tests/unit/embeddings/test_vertexai_embedder.py index 018230a67..960035466 100644 --- a/tests/unit/embeddings/test_vertexai_embedder.py +++ b/tests/unit/embeddings/test_vertexai_embedder.py @@ -16,6 +16,7 @@ import pytest from neo4j_graphrag.embeddings.vertexai import VertexAIEmbeddings +from neo4j_graphrag.exceptions import EmbeddingsGenerationError @patch("neo4j_graphrag.embeddings.vertexai.TextEmbeddingModel", None) @@ -33,3 +34,15 @@ def test_vertexai_embedder_happy_path(mock_vertexai: Mock) -> None: res = embedder.embed_query("my text") assert isinstance(res, list) assert res == [1.0, 2.0] + + +@patch("neo4j_graphrag.embeddings.vertexai.TextEmbeddingModel") +def test_vertexai_embedder_error_handling(mock_vertexai: Mock) -> None: + mock_vertexai.from_pretrained.return_value.get_embeddings.side_effect = Exception( + "API Error" + ) + embedder = VertexAIEmbeddings() + with pytest.raises( + EmbeddingsGenerationError, match="Failed to generate embedding with VertexAI" + ): + embedder.embed_query("my text") From 1fff12b5a8d6577b72d6b0558fccbf37cb9c562d Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Sat, 20 Sep 2025 19:59:08 +0200 Subject: [PATCH 03/14] Update changelog and docs --- CHANGELOG.md | 4 ++++ docs/source/user_guide_rag.rst | 31 +++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 35a2154fd..1c0b6c9b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## Next +### Added + +- Added automatic rate limiting with retry logic and exponential backoff for all Embedding providers using tenacity. The `RateLimitHandler` interface allows for custom rate limiting strategies, including the ability to disable rate limiting entirely. + ## 1.10.0 ### Added diff --git a/docs/source/user_guide_rag.rst b/docs/source/user_guide_rag.rst index 0a6afed34..f381b963f 100644 --- a/docs/source/user_guide_rag.rst +++ b/docs/source/user_guide_rag.rst @@ -528,6 +528,37 @@ The `OpenAIEmbeddings` was illustrated previously. Here is how to use the `Sente If another embedder is desired, a custom embedder can be created, using the `Embedder` interface. +Embedder Rate Limiting +---------------------- + +All embedder implementations include automatic rate limiting that uses retry logic with exponential backoff by default, similar to LLM implementations. This feature helps handle API rate limits from embedding providers gracefully. + +.. code:: python + + from neo4j_graphrag.embeddings import OpenAIEmbeddings + from neo4j_graphrag.llm.rate_limit import RetryRateLimitHandler, NoOpRateLimitHandler + + # Default rate limiting (automatically enabled) + embedder = OpenAIEmbeddings(model="text-embedding-3-large") + + # Custom rate limiting configuration + embedder = OpenAIEmbeddings( + model="text-embedding-3-large", + rate_limit_handler=RetryRateLimitHandler( + max_attempts=5, + min_wait=2.0, + max_wait=120.0 + ) + ) + + # Disable rate limiting + embedder = OpenAIEmbeddings( + model="text-embedding-3-large", + rate_limit_handler=NoOpRateLimitHandler() + ) + +The rate limiting configuration works the same way as for LLMs. See the :ref:`Rate Limit Handling ` section above for more details on customization options. + Other Vector Retriever Configuration ---------------------------------------- From 54c28a45651cc15c5844f273696b732eb65c1912 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Sat, 20 Sep 2025 20:12:52 +0200 Subject: [PATCH 04/14] Ruff --- src/neo4j_graphrag/embeddings/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neo4j_graphrag/embeddings/base.py b/src/neo4j_graphrag/embeddings/base.py index 3f131f66a..cd5317fde 100644 --- a/src/neo4j_graphrag/embeddings/base.py +++ b/src/neo4j_graphrag/embeddings/base.py @@ -15,7 +15,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Optional from neo4j_graphrag.llm.rate_limit import ( DEFAULT_RATE_LIMIT_HANDLER, From c9b4798dd32e29d64138ba2e82068128f07a6bbf Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Sat, 20 Sep 2025 23:33:48 +0200 Subject: [PATCH 05/14] Fix mypy --- src/neo4j_graphrag/embeddings/vertexai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neo4j_graphrag/embeddings/vertexai.py b/src/neo4j_graphrag/embeddings/vertexai.py index 20c4c76b7..0bc30c288 100644 --- a/src/neo4j_graphrag/embeddings/vertexai.py +++ b/src/neo4j_graphrag/embeddings/vertexai.py @@ -23,7 +23,7 @@ try: from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel except (ImportError, AttributeError): - TextEmbeddingModel = TextEmbeddingInput = None + TextEmbeddingModel = TextEmbeddingInput = None # type: ignore[misc, assignment] if TYPE_CHECKING: From 33c781056c82dcf17355061e3608705fc6927dc3 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Sat, 20 Sep 2025 23:50:18 +0200 Subject: [PATCH 06/14] Fix more mypy issues --- src/neo4j_graphrag/embeddings/sentence_transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neo4j_graphrag/embeddings/sentence_transformers.py b/src/neo4j_graphrag/embeddings/sentence_transformers.py index b8fad4259..a50fac19c 100644 --- a/src/neo4j_graphrag/embeddings/sentence_transformers.py +++ b/src/neo4j_graphrag/embeddings/sentence_transformers.py @@ -43,7 +43,7 @@ def __init__( self.model = sentence_transformers.SentenceTransformer(model, *args, **kwargs) @rate_limit_handler - def embed_query(self, text: str) -> list[float]: + def embed_query(self, text: str) -> Any: try: result = self.model.encode([text]) From 43c18d4578385df3f725efb678be8f189ee9845e Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Fri, 26 Sep 2025 17:43:10 +0200 Subject: [PATCH 07/14] Improve error handling for Mistral and Sentence Transformers --- src/neo4j_graphrag/embeddings/mistral.py | 12 +++++++++--- .../embeddings/sentence_transformers.py | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/neo4j_graphrag/embeddings/mistral.py b/src/neo4j_graphrag/embeddings/mistral.py index b356f3ce0..f362ecb26 100644 --- a/src/neo4j_graphrag/embeddings/mistral.py +++ b/src/neo4j_graphrag/embeddings/mistral.py @@ -64,9 +64,15 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]: text (str): The text to generate an embedding for. **kwargs (Any): Additional keyword arguments to pass to the Mistral AI client. """ - embeddings_batch_response = self.mistral_client.embeddings.create( - model=self.model, inputs=[text], **kwargs - ) + try: + embeddings_batch_response = self.mistral_client.embeddings.create( + model=self.model, inputs=[text], **kwargs + ) + except Exception as e: + raise EmbeddingsGenerationError( + f"Failed to generate embedding with MistralAI: {e}" + ) from e + if embeddings_batch_response is None or not embeddings_batch_response.data: raise EmbeddingsGenerationError("Failed to retrieve embeddings.") diff --git a/src/neo4j_graphrag/embeddings/sentence_transformers.py b/src/neo4j_graphrag/embeddings/sentence_transformers.py index a50fac19c..5368800b3 100644 --- a/src/neo4j_graphrag/embeddings/sentence_transformers.py +++ b/src/neo4j_graphrag/embeddings/sentence_transformers.py @@ -59,5 +59,5 @@ def embed_query(self, text: str) -> Any: raise ValueError("Unexpected return type from model encoding") except Exception as e: raise EmbeddingsGenerationError( - "Failed to generate embedding with SentenceTransformer" + f"Failed to generate embedding with SentenceTransformer: {e}" ) from e From 4bc0ef51184da834bc0391d7d8e3b9904323ae73 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Fri, 26 Sep 2025 17:43:26 +0200 Subject: [PATCH 08/14] Improve unit tests --- tests/unit/embeddings/test_cohere_embedder.py | 51 ++++++++++++- .../embeddings/test_mistralai_embedder.py | 71 +++++++++++++++++++ tests/unit/embeddings/test_openai_embedder.py | 61 ++++++++++++++-- .../embeddings/test_sentence_transformers.py | 55 +++++++++++++- .../unit/embeddings/test_vertexai_embedder.py | 54 ++++++++++++-- 5 files changed, 281 insertions(+), 11 deletions(-) diff --git a/tests/unit/embeddings/test_cohere_embedder.py b/tests/unit/embeddings/test_cohere_embedder.py index c962d7450..af87f4a24 100644 --- a/tests/unit/embeddings/test_cohere_embedder.py +++ b/tests/unit/embeddings/test_cohere_embedder.py @@ -15,6 +15,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from tenacity import RetryError from neo4j_graphrag.embeddings.cohere import CohereEmbeddings from neo4j_graphrag.exceptions import EmbeddingsGenerationError @@ -36,10 +37,56 @@ def test_cohere_embedder_happy_path(mock_cohere: Mock) -> None: @patch("neo4j_graphrag.embeddings.cohere.cohere") -def test_cohere_embedder_error_handling(mock_cohere: Mock) -> None: - mock_cohere.Client.return_value.embed.side_effect = Exception("API Error") +def test_cohere_embedder_non_retryable_error_handling(mock_cohere: Mock) -> None: + """Test that non-retryable errors fail immediately without retries.""" + mock_embeddings = mock_cohere.Client.return_value.embed + mock_embeddings.side_effect = Exception("API Error") embedder = CohereEmbeddings() with pytest.raises( EmbeddingsGenerationError, match="Failed to generate embedding with Cohere" ): embedder.embed_query("my text") + + # Verify the API was called only once (no retries for non-rate-limit errors) + assert mock_embeddings.call_count == 1 + + +@patch("neo4j_graphrag.embeddings.cohere.cohere") +def test_cohere_embedder_rate_limit_error_retries(mock_cohere: Mock) -> None: + """Test that rate limit errors are retried the expected number of times.""" + # Rate limit error that should trigger retries (matches "too many requests" pattern) + # Create separate exception instances for each retry attempt + mock_embeddings = mock_cohere.Client.return_value.embed + mock_embeddings.side_effect = [ + Exception("too many requests - please try again later"), + Exception("too many requests - please try again later"), + Exception("too many requests - please try again later"), + ] + embedder = CohereEmbeddings() + + # After exhausting retries, tenacity raises RetryError + with pytest.raises(RetryError): + embedder.embed_query("my text") + + # Verify the API was called 3 times (default max_attempts for RetryRateLimitHandler) + assert mock_cohere.Client.return_value.embed.call_count == 3 + + +@patch("neo4j_graphrag.embeddings.cohere.cohere") +def test_cohere_embedder_rate_limit_error_eventual_success(mock_cohere: Mock) -> None: + """Test that rate limit errors eventually succeed after retries.""" + # First two calls fail with rate limit, third succeeds + mock_embeddings = mock_cohere.Client.return_value.embed + mock_embeddings.side_effect = [ + Exception("too many requests - please try again later"), + Exception("too many requests - please try again later"), + MagicMock(embeddings=[[1.0, 2.0]]), + ] + embedder = CohereEmbeddings() + + result = embedder.embed_query("my text") + + # Verify successful result + assert result == [1.0, 2.0] + # Verify the API was called 3 times before succeeding + assert mock_embeddings.call_count == 3 diff --git a/tests/unit/embeddings/test_mistralai_embedder.py b/tests/unit/embeddings/test_mistralai_embedder.py index f962f7c37..f2fd4cbb4 100644 --- a/tests/unit/embeddings/test_mistralai_embedder.py +++ b/tests/unit/embeddings/test_mistralai_embedder.py @@ -15,7 +15,9 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from tenacity import RetryError from neo4j_graphrag.embeddings import MistralAIEmbeddings +from neo4j_graphrag.exceptions import EmbeddingsGenerationError @patch("neo4j_graphrag.embeddings.mistral.Mistral", None) @@ -72,3 +74,72 @@ def test_mistralai_embedder_api_key_from_env( mock_getenv.assert_called_with("MISTRAL_API_KEY", "") mock_mistral.assert_called_with(api_key="env_api_key") + + +@patch("neo4j_graphrag.embeddings.mistral.Mistral") +def test_mistralai_embedder_non_retryable_error_handling(mock_mistral: Mock) -> None: + """Test that non-retryable errors fail immediately without retries.""" + mock_mistral_instance = mock_mistral.return_value + mock_embeddings = mock_mistral_instance.embeddings.create + mock_embeddings.side_effect = Exception("API Error") + + embedder = MistralAIEmbeddings() + # MistralAI now wraps exceptions, so we expect EmbeddingsGenerationError + with pytest.raises( + EmbeddingsGenerationError, match="Failed to generate embedding with MistralAI" + ): + embedder.embed_query("my text") + + # Verify the API was called only once (no retries for non-rate-limit errors) + assert mock_embeddings.call_count == 1 + + +@patch("neo4j_graphrag.embeddings.mistral.Mistral") +def test_mistralai_embedder_rate_limit_error_retries(mock_mistral: Mock) -> None: + """Test that rate limit errors are retried the expected number of times.""" + mock_mistral_instance = mock_mistral.return_value + + # Rate limit error that should trigger retries (matches "too many requests" pattern) + # Create separate exception instances for each retry attempt + mock_embeddings = mock_mistral_instance.embeddings.create + mock_embeddings.side_effect = [ + Exception("too many requests - rate limit exceeded"), + Exception("too many requests - rate limit exceeded"), + Exception("too many requests - rate limit exceeded"), + ] + + embedder = MistralAIEmbeddings() + + # After exhausting retries, tenacity raises RetryError + with pytest.raises(RetryError): + embedder.embed_query("my text") + + # Verify the API was called 3 times (default max_attempts for RetryRateLimitHandler) + assert mock_embeddings.call_count == 3 + + +@patch("neo4j_graphrag.embeddings.mistral.Mistral") +def test_mistralai_embedder_rate_limit_error_eventual_success( + mock_mistral: Mock, +) -> None: + """Test that rate limit errors eventually succeed after retries.""" + mock_mistral_instance = mock_mistral.return_value + + # First two calls fail with rate limit, third succeeds + embeddings_batch_response_mock = MagicMock() + embeddings_batch_response_mock.data = [MagicMock(embedding=[1.0, 2.0])] + + mock_embeddings = mock_mistral_instance.embeddings.create + mock_embeddings.side_effect = [ + Exception("too many requests - rate limit exceeded"), + Exception("too many requests - rate limit exceeded"), + embeddings_batch_response_mock, + ] + + embedder = MistralAIEmbeddings() + result = embedder.embed_query("my text") + + # Verify successful result + assert result == [1.0, 2.0] + # Verify the API was called 3 times before succeeding + assert mock_embeddings.call_count == 3 diff --git a/tests/unit/embeddings/test_openai_embedder.py b/tests/unit/embeddings/test_openai_embedder.py index 046a9519f..cdfc3d6c5 100644 --- a/tests/unit/embeddings/test_openai_embedder.py +++ b/tests/unit/embeddings/test_openai_embedder.py @@ -16,6 +16,7 @@ import openai import pytest +from tenacity import RetryError from neo4j_graphrag.embeddings.openai import ( AzureOpenAIEmbeddings, OpenAIEmbeddings, @@ -96,15 +97,67 @@ def test_azure_openai_embedder_does_not_call_openai_client() -> None: @patch("builtins.__import__") -def test_openai_embedder_error_handling(mock_import: Mock) -> None: +def test_openai_embedder_non_retryable_error_handling(mock_import: Mock) -> None: + """Test that non-retryable errors fail immediately without retries.""" mock_openai = get_mock_openai() mock_import.return_value = mock_openai - mock_openai.OpenAI.return_value.embeddings.create.side_effect = Exception( - "API Error" - ) + # Generic API error that doesn't match rate limit patterns - should not be retried + mock_embeddings = mock_openai.OpenAI.return_value.embeddings.create + mock_embeddings.side_effect = Exception("API Error") embedder = OpenAIEmbeddings(api_key="my key") + with pytest.raises( EmbeddingsGenerationError, match="Failed to generate embedding with OpenAI" ): embedder.embed_query("my text") + + # Verify the API was called only once (no retries for non-rate-limit errors) + assert mock_embeddings.call_count == 1 + + +@patch("builtins.__import__") +def test_openai_embedder_rate_limit_error_retries(mock_import: Mock) -> None: + """Test that rate limit errors are retried the expected number of times.""" + mock_openai = get_mock_openai() + mock_import.return_value = mock_openai + + # Rate limit error that should trigger retries (matches "429" pattern) + # Create separate exception instances for each retry attempt + mock_embeddings = mock_openai.OpenAI.return_value.embeddings.create + mock_embeddings.side_effect = [ + Exception("Error code: 429 - Too many requests"), + Exception("Error code: 429 - Too many requests"), + Exception("Error code: 429 - Too many requests"), + ] + embedder = OpenAIEmbeddings(api_key="my key") + + # After exhausting retries, tenacity raises RetryError + with pytest.raises(RetryError): + embedder.embed_query("my text") + + # Verify the API was called 3 times (default max_attempts for RetryRateLimitHandler) + assert mock_embeddings.call_count == 3 + + +@patch("builtins.__import__") +def test_openai_embedder_rate_limit_error_eventual_success(mock_import: Mock) -> None: + """Test that rate limit errors eventually succeed after retries.""" + mock_openai = get_mock_openai() + mock_import.return_value = mock_openai + + # First two calls fail with rate limit, third succeeds + mock_embeddings = mock_openai.OpenAI.return_value.embeddings.create + mock_embeddings.side_effect = [ + Exception("Error code: 429 - Too many requests"), + Exception("Error code: 429 - Too many requests"), + MagicMock(data=[MagicMock(embedding=[1.0, 2.0])]), + ] + embedder = OpenAIEmbeddings(api_key="my key") + + result = embedder.embed_query("my text") + + # Verify successful result + assert result == [1.0, 2.0] + # Verify the API was called 3 times before succeeding + assert mock_embeddings.call_count == 3 diff --git a/tests/unit/embeddings/test_sentence_transformers.py b/tests/unit/embeddings/test_sentence_transformers.py index 7a8db0817..f9c8f36bc 100644 --- a/tests/unit/embeddings/test_sentence_transformers.py +++ b/tests/unit/embeddings/test_sentence_transformers.py @@ -3,6 +3,7 @@ import numpy as np import pytest import torch +from tenacity import RetryError from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.embeddings.sentence_transformers import ( SentenceTransformerEmbeddings, @@ -59,7 +60,8 @@ def test_import_error(mock_import: Mock) -> None: @patch("builtins.__import__") -def test_embed_query_error_handling(mock_import: Mock) -> None: +def test_embed_query_non_retryable_error_handling(mock_import: Mock) -> None: + """Test that non-retryable errors fail immediately without retries.""" MockSentenceTransformer = get_mock_sentence_transformers() mock_import.return_value = MockSentenceTransformer mock_model = MockSentenceTransformer.SentenceTransformer.return_value @@ -71,3 +73,54 @@ def test_embed_query_error_handling(mock_import: Mock) -> None: match="Failed to generate embedding with SentenceTransformer", ): instance.embed_query("test query") + + # Verify the model was called only once (no retries for non-rate-limit errors) + assert mock_model.encode.call_count == 1 + + +@patch("builtins.__import__") +def test_embed_query_rate_limit_error_retries(mock_import: Mock) -> None: + """Test that rate limit errors are retried the expected number of times.""" + MockSentenceTransformer = get_mock_sentence_transformers() + mock_import.return_value = MockSentenceTransformer + mock_model = MockSentenceTransformer.SentenceTransformer.return_value + + # Rate limit error that should trigger retries (matches "too many requests" pattern) + # Create separate exception instances for each retry attempt + mock_model.encode.side_effect = [ + Exception("too many requests - please wait"), + Exception("too many requests - please wait"), + Exception("too many requests - please wait"), + ] + + instance = SentenceTransformerEmbeddings() + + # After exhausting retries, tenacity raises RetryError (since retries should work) + with pytest.raises(RetryError): + instance.embed_query("test query") + + # Verify the model was called 3 times (default max_attempts for RetryRateLimitHandler) + assert mock_model.encode.call_count == 3 + + +@patch("builtins.__import__") +def test_embed_query_rate_limit_error_eventual_success(mock_import: Mock) -> None: + """Test that rate limit errors eventually succeed after retries.""" + MockSentenceTransformer = get_mock_sentence_transformers() + mock_import.return_value = MockSentenceTransformer + mock_model = MockSentenceTransformer.SentenceTransformer.return_value + + # First two calls fail with rate limit, third succeeds + mock_model.encode.side_effect = [ + Exception("too many requests - please wait"), + Exception("too many requests - please wait"), + np.array([[0.1, 0.2, 0.3]]), + ] + + instance = SentenceTransformerEmbeddings() + result = instance.embed_query("test query") + + # Verify successful result + assert result == [0.1, 0.2, 0.3] + # Verify the model was called 3 times before succeeding + assert mock_model.encode.call_count == 3 diff --git a/tests/unit/embeddings/test_vertexai_embedder.py b/tests/unit/embeddings/test_vertexai_embedder.py index 960035466..790d7e402 100644 --- a/tests/unit/embeddings/test_vertexai_embedder.py +++ b/tests/unit/embeddings/test_vertexai_embedder.py @@ -15,6 +15,7 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from tenacity import RetryError from neo4j_graphrag.embeddings.vertexai import VertexAIEmbeddings from neo4j_graphrag.exceptions import EmbeddingsGenerationError @@ -37,12 +38,57 @@ def test_vertexai_embedder_happy_path(mock_vertexai: Mock) -> None: @patch("neo4j_graphrag.embeddings.vertexai.TextEmbeddingModel") -def test_vertexai_embedder_error_handling(mock_vertexai: Mock) -> None: - mock_vertexai.from_pretrained.return_value.get_embeddings.side_effect = Exception( - "API Error" - ) +def test_vertexai_embedder_non_retryable_error_handling(mock_vertexai: Mock) -> None: + """Test that non-retryable errors fail immediately without retries.""" + mock_embeddings = mock_vertexai.from_pretrained.return_value.get_embeddings + mock_embeddings.side_effect = Exception("API Error") embedder = VertexAIEmbeddings() with pytest.raises( EmbeddingsGenerationError, match="Failed to generate embedding with VertexAI" ): embedder.embed_query("my text") + + # Verify the API was called only once (no retries for non-rate-limit errors) + assert mock_embeddings.call_count == 1 + + +@patch("neo4j_graphrag.embeddings.vertexai.TextEmbeddingModel") +def test_vertexai_embedder_rate_limit_error_retries(mock_vertexai: Mock) -> None: + """Test that rate limit errors are retried the expected number of times.""" + # Rate limit error that should trigger retries (matches "resource exhausted" pattern) + mock_embeddings = mock_vertexai.from_pretrained.return_value.get_embeddings + mock_embeddings.side_effect = [ + Exception("resource exhausted - quota exceeded"), + Exception("resource exhausted - quota exceeded"), + Exception("resource exhausted - quota exceeded"), + ] + embedder = VertexAIEmbeddings() + + # After exhausting retries, tenacity raises RetryError + with pytest.raises(RetryError): + embedder.embed_query("my text") + + # Verify the API was called 3 times (default max_attempts for RetryRateLimitHandler) + assert mock_embeddings.call_count == 3 + + +@patch("neo4j_graphrag.embeddings.vertexai.TextEmbeddingModel") +def test_vertexai_embedder_rate_limit_error_eventual_success( + mock_vertexai: Mock, +) -> None: + """Test that rate limit errors eventually succeed after retries.""" + # First two calls fail with rate limit, third succeeds + mock_embeddings = mock_vertexai.from_pretrained.return_value.get_embeddings + mock_embeddings.side_effect = [ + Exception("resource exhausted - quota exceeded"), + Exception("resource exhausted - quota exceeded"), + [MagicMock(values=[1.0, 2.0])], + ] + embedder = VertexAIEmbeddings() + + result = embedder.embed_query("my text") + + # Verify successful result + assert result == [1.0, 2.0] + # Verify the API was called 3 times before succeeding + assert mock_embeddings.call_count == 3 From a245f73dd153e733b699c084154482880e92702f Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Fri, 26 Sep 2025 19:45:46 +0200 Subject: [PATCH 09/14] Move rate limit handler decorator to base class --- .../customize/embeddings/custom_embeddings.py | 2 +- .../hybrid_retrievers/hybrid_cypher_search.py | 2 +- .../retrievers/hybrid_retrievers/hybrid_search.py | 2 +- src/neo4j_graphrag/embeddings/base.py | 15 ++++++++++++++- src/neo4j_graphrag/embeddings/cohere.py | 5 ++--- src/neo4j_graphrag/embeddings/mistral.py | 5 ++--- src/neo4j_graphrag/embeddings/ollama.py | 5 ++--- src/neo4j_graphrag/embeddings/openai.py | 5 ++--- .../embeddings/sentence_transformers.py | 5 ++--- src/neo4j_graphrag/embeddings/vertexai.py | 5 ++--- tests/e2e/conftest.py | 4 ++-- 11 files changed, 31 insertions(+), 24 deletions(-) diff --git a/examples/customize/embeddings/custom_embeddings.py b/examples/customize/embeddings/custom_embeddings.py index 5b15eb0f7..e77127359 100644 --- a/examples/customize/embeddings/custom_embeddings.py +++ b/examples/customize/embeddings/custom_embeddings.py @@ -8,7 +8,7 @@ class CustomEmbeddings(Embedder): def __init__(self, dimension: int = 10, **kwargs: Any): self.dimension = dimension - def embed_query(self, input: str) -> list[float]: + def _embed_query(self, input: str) -> list[float]: return [random.random() for _ in range(self.dimension)] diff --git a/examples/customize/retrievers/hybrid_retrievers/hybrid_cypher_search.py b/examples/customize/retrievers/hybrid_retrievers/hybrid_cypher_search.py index e1d59e379..6268e82fd 100644 --- a/examples/customize/retrievers/hybrid_retrievers/hybrid_cypher_search.py +++ b/examples/customize/retrievers/hybrid_retrievers/hybrid_cypher_search.py @@ -20,7 +20,7 @@ # Create Embedder object class CustomEmbedder(Embedder): - def embed_query(self, text: str) -> list[float]: + def _embed_query(self, text: str) -> list[float]: return [random() for _ in range(DIMENSION)] diff --git a/examples/customize/retrievers/hybrid_retrievers/hybrid_search.py b/examples/customize/retrievers/hybrid_retrievers/hybrid_search.py index 69940596a..b9b9dd792 100644 --- a/examples/customize/retrievers/hybrid_retrievers/hybrid_search.py +++ b/examples/customize/retrievers/hybrid_retrievers/hybrid_search.py @@ -20,7 +20,7 @@ # Create Embedder object class CustomEmbedder(Embedder): - def embed_query(self, text: str) -> list[float]: + def _embed_query(self, text: str) -> list[float]: return [random() for _ in range(DIMENSION)] diff --git a/src/neo4j_graphrag/embeddings/base.py b/src/neo4j_graphrag/embeddings/base.py index cd5317fde..dae738476 100644 --- a/src/neo4j_graphrag/embeddings/base.py +++ b/src/neo4j_graphrag/embeddings/base.py @@ -20,6 +20,7 @@ from neo4j_graphrag.llm.rate_limit import ( DEFAULT_RATE_LIMIT_HANDLER, RateLimitHandler, + rate_limit_handler, ) @@ -38,7 +39,7 @@ def __init__(self, rate_limit_handler: Optional[RateLimitHandler] = None): else: self._rate_limit_handler = DEFAULT_RATE_LIMIT_HANDLER - @abstractmethod + @rate_limit_handler def embed_query(self, text: str) -> list[float]: """Embed query text. @@ -48,3 +49,15 @@ def embed_query(self, text: str) -> list[float]: Returns: list[float]: A vector embedding. """ + return self._embed_query(text) + + @abstractmethod + def _embed_query(self, text: str) -> list[float]: + """Embed query text. + + Args: + text (str): Text to convert to vector embedding + + Returns: + list[float]: A vector embedding. + """ diff --git a/src/neo4j_graphrag/embeddings/cohere.py b/src/neo4j_graphrag/embeddings/cohere.py index d8df69036..e6185843d 100644 --- a/src/neo4j_graphrag/embeddings/cohere.py +++ b/src/neo4j_graphrag/embeddings/cohere.py @@ -18,7 +18,7 @@ from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError -from neo4j_graphrag.llm.rate_limit import RateLimitHandler, rate_limit_handler +from neo4j_graphrag.llm.rate_limit import RateLimitHandler try: import cohere @@ -42,8 +42,7 @@ def __init__( self.model = model self.client = cohere.Client(**kwargs) - @rate_limit_handler - def embed_query(self, text: str, **kwargs: Any) -> list[float]: + def _embed_query(self, text: str, **kwargs: Any) -> list[float]: try: response = self.client.embed( texts=[text], diff --git a/src/neo4j_graphrag/embeddings/mistral.py b/src/neo4j_graphrag/embeddings/mistral.py index f362ecb26..768c83c24 100644 --- a/src/neo4j_graphrag/embeddings/mistral.py +++ b/src/neo4j_graphrag/embeddings/mistral.py @@ -20,7 +20,7 @@ from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError -from neo4j_graphrag.llm.rate_limit import RateLimitHandler, rate_limit_handler +from neo4j_graphrag.llm.rate_limit import RateLimitHandler try: from mistralai import Mistral @@ -55,8 +55,7 @@ def __init__( self.model = model self.mistral_client = Mistral(api_key=api_key, **kwargs) - @rate_limit_handler - def embed_query(self, text: str, **kwargs: Any) -> list[float]: + def _embed_query(self, text: str, **kwargs: Any) -> list[float]: """ Generate embeddings for a given query using a Mistral AI text embedding model. diff --git a/src/neo4j_graphrag/embeddings/ollama.py b/src/neo4j_graphrag/embeddings/ollama.py index 5b818ee2d..15ad76683 100644 --- a/src/neo4j_graphrag/embeddings/ollama.py +++ b/src/neo4j_graphrag/embeddings/ollama.py @@ -19,7 +19,7 @@ from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError -from neo4j_graphrag.llm.rate_limit import RateLimitHandler, rate_limit_handler +from neo4j_graphrag.llm.rate_limit import RateLimitHandler class OllamaEmbeddings(Embedder): @@ -48,8 +48,7 @@ def __init__( self.model = model self.client = ollama.Client(**kwargs) - @rate_limit_handler - def embed_query(self, text: str, **kwargs: Any) -> list[float]: + def _embed_query(self, text: str, **kwargs: Any) -> list[float]: """ Generate embeddings for a given query using an Ollama text embedding model. diff --git a/src/neo4j_graphrag/embeddings/openai.py b/src/neo4j_graphrag/embeddings/openai.py index 880d8d99e..02cbeb3b3 100644 --- a/src/neo4j_graphrag/embeddings/openai.py +++ b/src/neo4j_graphrag/embeddings/openai.py @@ -20,7 +20,7 @@ from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError -from neo4j_graphrag.llm.rate_limit import RateLimitHandler, rate_limit_handler +from neo4j_graphrag.llm.rate_limit import RateLimitHandler if TYPE_CHECKING: import openai @@ -59,8 +59,7 @@ def _initialize_client(self, **kwargs: Any) -> Any: """ pass - @rate_limit_handler - def embed_query(self, text: str, **kwargs: Any) -> list[float]: + def _embed_query(self, text: str, **kwargs: Any) -> list[float]: """ Generate embeddings for a given query using an OpenAI text embedding model. diff --git a/src/neo4j_graphrag/embeddings/sentence_transformers.py b/src/neo4j_graphrag/embeddings/sentence_transformers.py index 5368800b3..a6ec0cde1 100644 --- a/src/neo4j_graphrag/embeddings/sentence_transformers.py +++ b/src/neo4j_graphrag/embeddings/sentence_transformers.py @@ -17,7 +17,7 @@ from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError -from neo4j_graphrag.llm.rate_limit import RateLimitHandler, rate_limit_handler +from neo4j_graphrag.llm.rate_limit import RateLimitHandler class SentenceTransformerEmbeddings(Embedder): @@ -42,8 +42,7 @@ def __init__( self.np = np self.model = sentence_transformers.SentenceTransformer(model, *args, **kwargs) - @rate_limit_handler - def embed_query(self, text: str) -> Any: + def _embed_query(self, text: str) -> Any: try: result = self.model.encode([text]) diff --git a/src/neo4j_graphrag/embeddings/vertexai.py b/src/neo4j_graphrag/embeddings/vertexai.py index 0bc30c288..bff5357b5 100644 --- a/src/neo4j_graphrag/embeddings/vertexai.py +++ b/src/neo4j_graphrag/embeddings/vertexai.py @@ -18,7 +18,7 @@ from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError -from neo4j_graphrag.llm.rate_limit import RateLimitHandler, rate_limit_handler +from neo4j_graphrag.llm.rate_limit import RateLimitHandler try: from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel @@ -52,8 +52,7 @@ def __init__( super().__init__(rate_limit_handler) self.model = TextEmbeddingModel.from_pretrained(model) - @rate_limit_handler - def embed_query( + def _embed_query( self, text: str, task_type: str = "RETRIEVAL_QUERY", **kwargs: Any ) -> list[float]: """ diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 9932f12e5..4ae462a12 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -58,12 +58,12 @@ def embedder() -> Embedder: class RandomEmbedder(Embedder): - def embed_query(self, text: str) -> list[float]: + def _embed_query(self, text: str) -> list[float]: return [random.random() for _ in range(1536)] class BiologyEmbedder(Embedder): - def embed_query(self, text: str) -> list[float]: + def _embed_query(self, text: str) -> list[float]: if text == "biology": return EMBEDDING_BIOLOGY raise ValueError(f"Unknown embedding text: {text}") From a454edd50ef43f69250f569c3ad721dfec5b2690 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 30 Sep 2025 10:04:56 +0200 Subject: [PATCH 10/14] Move rate limit module to utils and generate deprecation warnings --- src/neo4j_graphrag/llm/rate_limit.py | 297 ++++++------------------- src/neo4j_graphrag/utils/rate_limit.py | 254 +++++++++++++++++++++ 2 files changed, 316 insertions(+), 235 deletions(-) create mode 100644 src/neo4j_graphrag/utils/rate_limit.py diff --git a/src/neo4j_graphrag/llm/rate_limit.py b/src/neo4j_graphrag/llm/rate_limit.py index 165de21ee..4461c6d22 100644 --- a/src/neo4j_graphrag/llm/rate_limit.py +++ b/src/neo4j_graphrag/llm/rate_limit.py @@ -12,243 +12,70 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations -import functools -import logging -from abc import ABC, abstractmethod -from typing import Any, Awaitable, Callable, TypeVar - -from neo4j_graphrag.exceptions import RateLimitError - -from tenacity import ( - retry, - stop_after_attempt, - wait_exponential, - wait_random_exponential, - retry_if_exception_type, - before_sleep_log, +""" +Deprecated module: Rate limiting functionality has been moved to neo4j_graphrag.utils.rate_limit. + +This module provides backward compatibility with deprecation warnings. +All new code should import from neo4j_graphrag.utils.rate_limit instead. +""" + +import warnings +from typing import Any + +# Import the actual implementations from the new location +from neo4j_graphrag.utils.rate_limit import ( + RateLimitHandler as _RateLimitHandler, + NoOpRateLimitHandler as _NoOpRateLimitHandler, + RetryRateLimitHandler as _RetryRateLimitHandler, + rate_limit_handler as _rate_limit_handler, + async_rate_limit_handler as _async_rate_limit_handler, + is_rate_limit_error as _is_rate_limit_error, + convert_to_rate_limit_error as _convert_to_rate_limit_error, + DEFAULT_RATE_LIMIT_HANDLER as _DEFAULT_RATE_LIMIT_HANDLER, ) -logger = logging.getLogger(__name__) - -F = TypeVar("F", bound=Callable[..., Any]) -AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]]) - - -class RateLimitHandler(ABC): - """Abstract base class for rate limit handling strategies.""" - - @abstractmethod - def handle_sync(self, func: F) -> F: - """Apply rate limit handling to a synchronous function. - - Args: - func: The function to wrap with rate limit handling. - - Returns: - The wrapped function. - """ - pass - - @abstractmethod - def handle_async(self, func: AF) -> AF: - """Apply rate limit handling to an asynchronous function. - - Args: - func: The async function to wrap with rate limit handling. - - Returns: - The wrapped async function. - """ - pass - - -class NoOpRateLimitHandler(RateLimitHandler): - """A no-op rate limit handler that does not apply any rate limiting.""" - - def handle_sync(self, func: F) -> F: - """Return the function unchanged.""" - return func - - def handle_async(self, func: AF) -> AF: - """Return the async function unchanged.""" - return func - - -class RetryRateLimitHandler(RateLimitHandler): - """Rate limit handler using exponential backoff retry strategy. - - This handler uses tenacity for retry logic with exponential backoff. - - Args: - max_attempts: Maximum number of retry attempts. Defaults to 3. - min_wait: Minimum wait time between retries in seconds. Defaults to 1. - max_wait: Maximum wait time between retries in seconds. Defaults to 60. - multiplier: Exponential backoff multiplier. Defaults to 2. - jitter: Whether to add random jitter to retry delays to prevent thundering herd. Defaults to True. - """ - - def __init__( - self, - max_attempts: int = 3, - min_wait: float = 1.0, - max_wait: float = 60.0, - multiplier: float = 2.0, - jitter: bool = True, - ): - self.max_attempts = max_attempts - self.min_wait = min_wait - self.max_wait = max_wait - self.multiplier = multiplier - self.jitter = jitter - - def _get_wait_strategy(self) -> Any: - """Get the appropriate wait strategy based on jitter setting. - - Returns: - The configured wait strategy for tenacity retry. - """ - if self.jitter: - # Use built-in random exponential backoff with jitter - return wait_random_exponential( - multiplier=self.multiplier, - min=self.min_wait, - max=self.max_wait, - ) - else: - # Use standard exponential backoff without jitter - return wait_exponential( - multiplier=self.multiplier, - min=self.min_wait, - max=self.max_wait, - ) - - def handle_sync(self, func: F) -> F: - """Apply retry logic to a synchronous function.""" - decorator = retry( - retry=retry_if_exception_type(RateLimitError), - stop=stop_after_attempt(self.max_attempts), - wait=self._get_wait_strategy(), - before_sleep=before_sleep_log(logger, logging.WARNING), +def __getattr__(name: str) -> Any: + """Handle deprecated imports with warnings.""" + deprecated_items = { + "RateLimitHandler": _RateLimitHandler, + "NoOpRateLimitHandler": _NoOpRateLimitHandler, + "RetryRateLimitHandler": _RetryRateLimitHandler, + "rate_limit_handler": _rate_limit_handler, + "async_rate_limit_handler": _async_rate_limit_handler, + "is_rate_limit_error": _is_rate_limit_error, + "convert_to_rate_limit_error": _convert_to_rate_limit_error, + "DEFAULT_RATE_LIMIT_HANDLER": _DEFAULT_RATE_LIMIT_HANDLER, + } + + if name in deprecated_items: + warnings.warn( + f"{name} has been moved to neo4j_graphrag.utils.rate_limit. " + f"Please update your imports to use 'from neo4j_graphrag.utils.rate_limit import {name}'.", + DeprecationWarning, + stacklevel=2, ) - return decorator(func) - - def handle_async(self, func: AF) -> AF: - """Apply retry logic to an asynchronous function.""" - decorator = retry( - retry=retry_if_exception_type(RateLimitError), - stop=stop_after_attempt(self.max_attempts), - wait=self._get_wait_strategy(), - before_sleep=before_sleep_log(logger, logging.WARNING), - ) - return decorator(func) - - -def is_rate_limit_error(exception: Exception) -> bool: - """Check if an exception is a rate limit error from any LLM provider or embedder. - - Args: - exception: The exception to check. - - Returns: - True if the exception indicates a rate limit error, False otherwise. - """ - error_type = type(exception).__name__.lower() - exception_str = str(exception).lower() - - # For LLMGenerationError or EmbeddingsGenerationError (which wrap all provider errors), check provider-specific patterns - if error_type in ["llmgenerationerror", "embeddingsgenerationerror"]: - # Check for various rate limit patterns from different providers - rate_limit_patterns = [ - "error code: 429", # Azure OpenAI - "too many requests", # Anthropic, Cohere, MistralAI - "resource exhausted", # VertexAI - "rate limit", # Generic rate limit messages - "429", # Generic rate limit messages - ] - - return any(pattern in exception_str for pattern in rate_limit_patterns) - - return False - - -def convert_to_rate_limit_error(exception: Exception) -> RateLimitError: - """Convert a provider-specific rate limit exception to RateLimitError. - - Args: - exception: The original exception from the LLM provider. - - Returns: - A RateLimitError with the original exception message. - """ - return RateLimitError(f"Rate limit exceeded: {exception}") - - -def rate_limit_handler(func: F) -> F: - """Decorator to apply rate limit handling to synchronous methods. - - This decorator works with instance methods and uses the instance's rate limit handler. - - Args: - func: The function to wrap with rate limit handling. - - Returns: - The wrapped function. - """ - - @functools.wraps(func) - def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - # Use instance handler or default - active_handler = getattr( - self, "_rate_limit_handler", DEFAULT_RATE_LIMIT_HANDLER - ) - - def inner_func() -> Any: - try: - return func(self, *args, **kwargs) - except Exception as e: - if is_rate_limit_error(e): - raise convert_to_rate_limit_error(e) - raise - - return active_handler.handle_sync(inner_func)() - - return wrapper # type: ignore - - -def async_rate_limit_handler(func: AF) -> AF: - """Decorator to apply rate limit handling to asynchronous methods. - - This decorator works with instance methods and uses the instance's rate limit handler. - - Args: - func: The async function to wrap with rate limit handling. - - Returns: - The wrapped async function. - """ - - @functools.wraps(func) - async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - # Use instance handler or default - active_handler = getattr( - self, "_rate_limit_handler", DEFAULT_RATE_LIMIT_HANDLER - ) - - async def inner_func() -> Any: - try: - return await func(self, *args, **kwargs) - except Exception as e: - if is_rate_limit_error(e): - raise convert_to_rate_limit_error(e) - raise - - return await active_handler.handle_async(inner_func)() - - return wrapper # type: ignore - - -# Default rate limit handler instance -DEFAULT_RATE_LIMIT_HANDLER = RetryRateLimitHandler() + return deprecated_items[name] + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +# For backward compatibility, also expose the deprecated items at module level +# This handles cases where users do: import neo4j_graphrag.llm.rate_limit; rate_limit.RateLimitHandler +RateLimitHandler = _RateLimitHandler +NoOpRateLimitHandler = _NoOpRateLimitHandler +RetryRateLimitHandler = _RetryRateLimitHandler +rate_limit_handler = _rate_limit_handler +async_rate_limit_handler = _async_rate_limit_handler +is_rate_limit_error = _is_rate_limit_error +convert_to_rate_limit_error = _convert_to_rate_limit_error +DEFAULT_RATE_LIMIT_HANDLER = _DEFAULT_RATE_LIMIT_HANDLER + +# Issue deprecation warnings for module-level access +warnings.warn( + "The neo4j_graphrag.llm.rate_limit module has been moved to neo4j_graphrag.utils.rate_limit. " + "Please update your imports to use 'from neo4j_graphrag.utils.rate_limit import ...'.", + DeprecationWarning, + stacklevel=2, +) diff --git a/src/neo4j_graphrag/utils/rate_limit.py b/src/neo4j_graphrag/utils/rate_limit.py new file mode 100644 index 000000000..165de21ee --- /dev/null +++ b/src/neo4j_graphrag/utils/rate_limit.py @@ -0,0 +1,254 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import functools +import logging +from abc import ABC, abstractmethod +from typing import Any, Awaitable, Callable, TypeVar + +from neo4j_graphrag.exceptions import RateLimitError + +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + wait_random_exponential, + retry_if_exception_type, + before_sleep_log, +) + + +logger = logging.getLogger(__name__) + +F = TypeVar("F", bound=Callable[..., Any]) +AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]]) + + +class RateLimitHandler(ABC): + """Abstract base class for rate limit handling strategies.""" + + @abstractmethod + def handle_sync(self, func: F) -> F: + """Apply rate limit handling to a synchronous function. + + Args: + func: The function to wrap with rate limit handling. + + Returns: + The wrapped function. + """ + pass + + @abstractmethod + def handle_async(self, func: AF) -> AF: + """Apply rate limit handling to an asynchronous function. + + Args: + func: The async function to wrap with rate limit handling. + + Returns: + The wrapped async function. + """ + pass + + +class NoOpRateLimitHandler(RateLimitHandler): + """A no-op rate limit handler that does not apply any rate limiting.""" + + def handle_sync(self, func: F) -> F: + """Return the function unchanged.""" + return func + + def handle_async(self, func: AF) -> AF: + """Return the async function unchanged.""" + return func + + +class RetryRateLimitHandler(RateLimitHandler): + """Rate limit handler using exponential backoff retry strategy. + + This handler uses tenacity for retry logic with exponential backoff. + + Args: + max_attempts: Maximum number of retry attempts. Defaults to 3. + min_wait: Minimum wait time between retries in seconds. Defaults to 1. + max_wait: Maximum wait time between retries in seconds. Defaults to 60. + multiplier: Exponential backoff multiplier. Defaults to 2. + jitter: Whether to add random jitter to retry delays to prevent thundering herd. Defaults to True. + """ + + def __init__( + self, + max_attempts: int = 3, + min_wait: float = 1.0, + max_wait: float = 60.0, + multiplier: float = 2.0, + jitter: bool = True, + ): + self.max_attempts = max_attempts + self.min_wait = min_wait + self.max_wait = max_wait + self.multiplier = multiplier + self.jitter = jitter + + def _get_wait_strategy(self) -> Any: + """Get the appropriate wait strategy based on jitter setting. + + Returns: + The configured wait strategy for tenacity retry. + """ + if self.jitter: + # Use built-in random exponential backoff with jitter + return wait_random_exponential( + multiplier=self.multiplier, + min=self.min_wait, + max=self.max_wait, + ) + else: + # Use standard exponential backoff without jitter + return wait_exponential( + multiplier=self.multiplier, + min=self.min_wait, + max=self.max_wait, + ) + + def handle_sync(self, func: F) -> F: + """Apply retry logic to a synchronous function.""" + decorator = retry( + retry=retry_if_exception_type(RateLimitError), + stop=stop_after_attempt(self.max_attempts), + wait=self._get_wait_strategy(), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + return decorator(func) + + def handle_async(self, func: AF) -> AF: + """Apply retry logic to an asynchronous function.""" + decorator = retry( + retry=retry_if_exception_type(RateLimitError), + stop=stop_after_attempt(self.max_attempts), + wait=self._get_wait_strategy(), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + return decorator(func) + + +def is_rate_limit_error(exception: Exception) -> bool: + """Check if an exception is a rate limit error from any LLM provider or embedder. + + Args: + exception: The exception to check. + + Returns: + True if the exception indicates a rate limit error, False otherwise. + """ + error_type = type(exception).__name__.lower() + exception_str = str(exception).lower() + + # For LLMGenerationError or EmbeddingsGenerationError (which wrap all provider errors), check provider-specific patterns + if error_type in ["llmgenerationerror", "embeddingsgenerationerror"]: + # Check for various rate limit patterns from different providers + rate_limit_patterns = [ + "error code: 429", # Azure OpenAI + "too many requests", # Anthropic, Cohere, MistralAI + "resource exhausted", # VertexAI + "rate limit", # Generic rate limit messages + "429", # Generic rate limit messages + ] + + return any(pattern in exception_str for pattern in rate_limit_patterns) + + return False + + +def convert_to_rate_limit_error(exception: Exception) -> RateLimitError: + """Convert a provider-specific rate limit exception to RateLimitError. + + Args: + exception: The original exception from the LLM provider. + + Returns: + A RateLimitError with the original exception message. + """ + return RateLimitError(f"Rate limit exceeded: {exception}") + + +def rate_limit_handler(func: F) -> F: + """Decorator to apply rate limit handling to synchronous methods. + + This decorator works with instance methods and uses the instance's rate limit handler. + + Args: + func: The function to wrap with rate limit handling. + + Returns: + The wrapped function. + """ + + @functools.wraps(func) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + # Use instance handler or default + active_handler = getattr( + self, "_rate_limit_handler", DEFAULT_RATE_LIMIT_HANDLER + ) + + def inner_func() -> Any: + try: + return func(self, *args, **kwargs) + except Exception as e: + if is_rate_limit_error(e): + raise convert_to_rate_limit_error(e) + raise + + return active_handler.handle_sync(inner_func)() + + return wrapper # type: ignore + + +def async_rate_limit_handler(func: AF) -> AF: + """Decorator to apply rate limit handling to asynchronous methods. + + This decorator works with instance methods and uses the instance's rate limit handler. + + Args: + func: The async function to wrap with rate limit handling. + + Returns: + The wrapped async function. + """ + + @functools.wraps(func) + async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + # Use instance handler or default + active_handler = getattr( + self, "_rate_limit_handler", DEFAULT_RATE_LIMIT_HANDLER + ) + + async def inner_func() -> Any: + try: + return await func(self, *args, **kwargs) + except Exception as e: + if is_rate_limit_error(e): + raise convert_to_rate_limit_error(e) + raise + + return await active_handler.handle_async(inner_func)() + + return wrapper # type: ignore + + +# Default rate limit handler instance +DEFAULT_RATE_LIMIT_HANDLER = RetryRateLimitHandler() From 00eea4825175131fceefa7cd55d8a624577cdceb Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 30 Sep 2025 10:07:06 +0200 Subject: [PATCH 11/14] Refactor modules using rate limit handling --- src/neo4j_graphrag/embeddings/base.py | 2 +- src/neo4j_graphrag/embeddings/cohere.py | 2 +- src/neo4j_graphrag/embeddings/mistral.py | 2 +- src/neo4j_graphrag/embeddings/ollama.py | 2 +- src/neo4j_graphrag/embeddings/openai.py | 2 +- src/neo4j_graphrag/embeddings/sentence_transformers.py | 2 +- src/neo4j_graphrag/embeddings/vertexai.py | 2 +- src/neo4j_graphrag/llm/__init__.py | 2 +- src/neo4j_graphrag/llm/anthropic_llm.py | 2 +- src/neo4j_graphrag/llm/base.py | 4 ++-- src/neo4j_graphrag/llm/cohere_llm.py | 2 +- src/neo4j_graphrag/llm/mistralai_llm.py | 2 +- src/neo4j_graphrag/llm/ollama_llm.py | 6 +++++- src/neo4j_graphrag/llm/openai_llm.py | 6 +++++- src/neo4j_graphrag/llm/vertexai_llm.py | 2 +- tests/unit/llm/test_rate_limit.py | 2 +- 16 files changed, 25 insertions(+), 17 deletions(-) diff --git a/src/neo4j_graphrag/embeddings/base.py b/src/neo4j_graphrag/embeddings/base.py index dae738476..02e5b7a51 100644 --- a/src/neo4j_graphrag/embeddings/base.py +++ b/src/neo4j_graphrag/embeddings/base.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from typing import Optional -from neo4j_graphrag.llm.rate_limit import ( +from neo4j_graphrag.utils.rate_limit import ( DEFAULT_RATE_LIMIT_HANDLER, RateLimitHandler, rate_limit_handler, diff --git a/src/neo4j_graphrag/embeddings/cohere.py b/src/neo4j_graphrag/embeddings/cohere.py index e6185843d..6d89fcca0 100644 --- a/src/neo4j_graphrag/embeddings/cohere.py +++ b/src/neo4j_graphrag/embeddings/cohere.py @@ -18,7 +18,7 @@ from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError -from neo4j_graphrag.llm.rate_limit import RateLimitHandler +from neo4j_graphrag.utils.rate_limit import RateLimitHandler try: import cohere diff --git a/src/neo4j_graphrag/embeddings/mistral.py b/src/neo4j_graphrag/embeddings/mistral.py index 768c83c24..2b1c3d284 100644 --- a/src/neo4j_graphrag/embeddings/mistral.py +++ b/src/neo4j_graphrag/embeddings/mistral.py @@ -20,7 +20,7 @@ from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError -from neo4j_graphrag.llm.rate_limit import RateLimitHandler +from neo4j_graphrag.utils.rate_limit import RateLimitHandler try: from mistralai import Mistral diff --git a/src/neo4j_graphrag/embeddings/ollama.py b/src/neo4j_graphrag/embeddings/ollama.py index 15ad76683..e70fe96be 100644 --- a/src/neo4j_graphrag/embeddings/ollama.py +++ b/src/neo4j_graphrag/embeddings/ollama.py @@ -19,7 +19,7 @@ from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError -from neo4j_graphrag.llm.rate_limit import RateLimitHandler +from neo4j_graphrag.utils.rate_limit import RateLimitHandler class OllamaEmbeddings(Embedder): diff --git a/src/neo4j_graphrag/embeddings/openai.py b/src/neo4j_graphrag/embeddings/openai.py index 02cbeb3b3..9bcf5df70 100644 --- a/src/neo4j_graphrag/embeddings/openai.py +++ b/src/neo4j_graphrag/embeddings/openai.py @@ -20,7 +20,7 @@ from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError -from neo4j_graphrag.llm.rate_limit import RateLimitHandler +from neo4j_graphrag.utils.rate_limit import RateLimitHandler if TYPE_CHECKING: import openai diff --git a/src/neo4j_graphrag/embeddings/sentence_transformers.py b/src/neo4j_graphrag/embeddings/sentence_transformers.py index a6ec0cde1..8dca9c4f6 100644 --- a/src/neo4j_graphrag/embeddings/sentence_transformers.py +++ b/src/neo4j_graphrag/embeddings/sentence_transformers.py @@ -17,7 +17,7 @@ from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError -from neo4j_graphrag.llm.rate_limit import RateLimitHandler +from neo4j_graphrag.utils.rate_limit import RateLimitHandler class SentenceTransformerEmbeddings(Embedder): diff --git a/src/neo4j_graphrag/embeddings/vertexai.py b/src/neo4j_graphrag/embeddings/vertexai.py index bff5357b5..e1792816f 100644 --- a/src/neo4j_graphrag/embeddings/vertexai.py +++ b/src/neo4j_graphrag/embeddings/vertexai.py @@ -18,7 +18,7 @@ from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError -from neo4j_graphrag.llm.rate_limit import RateLimitHandler +from neo4j_graphrag.utils.rate_limit import RateLimitHandler try: from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel diff --git a/src/neo4j_graphrag/llm/__init__.py b/src/neo4j_graphrag/llm/__init__.py index 3c4f65d9a..7b84d92ca 100644 --- a/src/neo4j_graphrag/llm/__init__.py +++ b/src/neo4j_graphrag/llm/__init__.py @@ -18,7 +18,7 @@ from .mistralai_llm import MistralAILLM from .ollama_llm import OllamaLLM from .openai_llm import AzureOpenAILLM, OpenAILLM -from .rate_limit import ( +from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, NoOpRateLimitHandler, RetryRateLimitHandler, diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py index 6bafef85b..21560d3f2 100644 --- a/src/neo4j_graphrag/llm/anthropic_llm.py +++ b/src/neo4j_graphrag/llm/anthropic_llm.py @@ -19,7 +19,7 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.llm.rate_limit import ( +from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, rate_limit_handler, async_rate_limit_handler, diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index cca710bc9..ff7af1c70 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -21,13 +21,13 @@ from neo4j_graphrag.types import LLMMessage from .types import LLMResponse, ToolCallResponse -from .rate_limit import ( +from neo4j_graphrag.utils.rate_limit import ( DEFAULT_RATE_LIMIT_HANDLER, ) from neo4j_graphrag.tool import Tool -from .rate_limit import RateLimitHandler +from neo4j_graphrag.utils.rate_limit import RateLimitHandler class LLMInterface(ABC): diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py index 7c3905500..2e3ca0cea 100644 --- a/src/neo4j_graphrag/llm/cohere_llm.py +++ b/src/neo4j_graphrag/llm/cohere_llm.py @@ -20,7 +20,7 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.llm.rate_limit import ( +from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, rate_limit_handler, async_rate_limit_handler, diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py index ae2a6312f..3fa8663ae 100644 --- a/src/neo4j_graphrag/llm/mistralai_llm.py +++ b/src/neo4j_graphrag/llm/mistralai_llm.py @@ -21,7 +21,7 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.llm.rate_limit import ( +from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, rate_limit_handler, async_rate_limit_handler, diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py index 214640625..94541e033 100644 --- a/src/neo4j_graphrag/llm/ollama_llm.py +++ b/src/neo4j_graphrag/llm/ollama_llm.py @@ -24,7 +24,11 @@ from neo4j_graphrag.types import LLMMessage from .base import LLMInterface -from .rate_limit import RateLimitHandler, rate_limit_handler, async_rate_limit_handler +from neo4j_graphrag.utils.rate_limit import ( + RateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, +) from .types import ( BaseMessage, LLMResponse, diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index d74c83dcf..afdf0234d 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -35,7 +35,11 @@ from ..exceptions import LLMGenerationError from .base import LLMInterface -from .rate_limit import RateLimitHandler, rate_limit_handler, async_rate_limit_handler +from neo4j_graphrag.utils.rate_limit import ( + RateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, +) from .types import ( BaseMessage, LLMResponse, diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index 0b4926978..b9f1e40e8 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -19,7 +19,7 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.llm.rate_limit import ( +from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, rate_limit_handler, async_rate_limit_handler, diff --git a/tests/unit/llm/test_rate_limit.py b/tests/unit/llm/test_rate_limit.py index f1f4b133b..51bc80b14 100644 --- a/tests/unit/llm/test_rate_limit.py +++ b/tests/unit/llm/test_rate_limit.py @@ -19,7 +19,7 @@ from unittest.mock import Mock from tenacity import RetryError -from neo4j_graphrag.llm.rate_limit import ( +from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, NoOpRateLimitHandler, DEFAULT_RATE_LIMIT_HANDLER, From 3d2f0a37692a1961da61f95a1643526463ea9e1d Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 30 Sep 2025 10:07:23 +0200 Subject: [PATCH 12/14] Update docs and examples --- docs/source/api.rst | 6 +++--- docs/source/user_guide_rag.rst | 6 +++--- examples/customize/llms/custom_llm.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 4066348b8..f891b4dec 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -359,19 +359,19 @@ Rate Limiting RateLimitHandler ---------------- -.. autoclass:: neo4j_graphrag.llm.rate_limit.RateLimitHandler +.. autoclass:: neo4j_graphrag.utils.rate_limit.RateLimitHandler :members: RetryRateLimitHandler --------------------- -.. autoclass:: neo4j_graphrag.llm.rate_limit.RetryRateLimitHandler +.. autoclass:: neo4j_graphrag.utils.rate_limit.RetryRateLimitHandler :members: NoOpRateLimitHandler -------------------- -.. autoclass:: neo4j_graphrag.llm.rate_limit.NoOpRateLimitHandler +.. autoclass:: neo4j_graphrag.utils.rate_limit.NoOpRateLimitHandler :members: diff --git a/docs/source/user_guide_rag.rst b/docs/source/user_guide_rag.rst index f381b963f..ce0e4fc83 100644 --- a/docs/source/user_guide_rag.rst +++ b/docs/source/user_guide_rag.rst @@ -327,7 +327,7 @@ Rate limiting is enabled by default for all LLM instances with the following con .. code:: python from neo4j_graphrag.llm import OpenAILLM - from neo4j_graphrag.llm.rate_limit import RetryRateLimitHandler + from neo4j_graphrag.utils.rate_limit import RetryRateLimitHandler # Customize rate limiting parameters llm = OpenAILLM( @@ -348,7 +348,7 @@ You can customize the rate limiting behavior by creating your own rate limit han .. code:: python from neo4j_graphrag.llm import AnthropicLLM - from neo4j_graphrag.llm.rate_limit import RateLimitHandler + from neo4j_graphrag.utils.rate_limit import RateLimitHandler class CustomRateLimitHandler(RateLimitHandler): """Implement your custom rate limiting strategy.""" @@ -536,7 +536,7 @@ All embedder implementations include automatic rate limiting that uses retry log .. code:: python from neo4j_graphrag.embeddings import OpenAIEmbeddings - from neo4j_graphrag.llm.rate_limit import RetryRateLimitHandler, NoOpRateLimitHandler + from neo4j_graphrag.utils.rate_limit import RetryRateLimitHandler, NoOpRateLimitHandler # Default rate limiting (automatically enabled) embedder = OpenAIEmbeddings(model="text-embedding-3-large") diff --git a/examples/customize/llms/custom_llm.py b/examples/customize/llms/custom_llm.py index 0eecfd878..86b3cb993 100644 --- a/examples/customize/llms/custom_llm.py +++ b/examples/customize/llms/custom_llm.py @@ -3,7 +3,7 @@ from typing import Any, Awaitable, Callable, List, Optional, TypeVar, Union from neo4j_graphrag.llm import LLMInterface, LLMResponse -from neo4j_graphrag.llm.rate_limit import ( +from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, # rate_limit_handler, # async_rate_limit_handler, From c1bb8c9a19fd58c70ef465885a2c14145b8c555c Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 30 Sep 2025 14:12:16 +0200 Subject: [PATCH 13/14] Fix refactoring --- src/neo4j_graphrag/llm/__init__.py | 53 +++++++++++++++++++++------- src/neo4j_graphrag/llm/rate_limit.py | 17 ++------- 2 files changed, 43 insertions(+), 27 deletions(-) diff --git a/src/neo4j_graphrag/llm/__init__.py b/src/neo4j_graphrag/llm/__init__.py index 7b84d92ca..d34984b02 100644 --- a/src/neo4j_graphrag/llm/__init__.py +++ b/src/neo4j_graphrag/llm/__init__.py @@ -12,22 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings +from typing import Any + from .anthropic_llm import AnthropicLLM from .base import LLMInterface from .cohere_llm import CohereLLM from .mistralai_llm import MistralAILLM from .ollama_llm import OllamaLLM from .openai_llm import AzureOpenAILLM, OpenAILLM -from neo4j_graphrag.utils.rate_limit import ( - RateLimitHandler, - NoOpRateLimitHandler, - RetryRateLimitHandler, - rate_limit_handler, - async_rate_limit_handler, -) from .types import LLMResponse from .vertexai_llm import VertexAILLM + __all__ = [ "AnthropicLLM", "CohereLLM", @@ -38,10 +35,40 @@ "VertexAILLM", "AzureOpenAILLM", "MistralAILLM", - # Rate limiting components - "RateLimitHandler", - "NoOpRateLimitHandler", - "RetryRateLimitHandler", - "rate_limit_handler", - "async_rate_limit_handler", ] + + +def __getattr__(name: str) -> Any: + """Handle deprecated imports with warnings.""" + from neo4j_graphrag.utils.rate_limit import ( + RateLimitHandler, + NoOpRateLimitHandler, + RetryRateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, + is_rate_limit_error, + convert_to_rate_limit_error, + DEFAULT_RATE_LIMIT_HANDLER, + ) + + deprecated_items = { + "RateLimitHandler": RateLimitHandler, + "NoOpRateLimitHandler": NoOpRateLimitHandler, + "RetryRateLimitHandler": RetryRateLimitHandler, + "rate_limit_handler": rate_limit_handler, + "async_rate_limit_handler": async_rate_limit_handler, + "is_rate_limit_error": is_rate_limit_error, + "convert_to_rate_limit_error": convert_to_rate_limit_error, + "DEFAULT_RATE_LIMIT_HANDLER": DEFAULT_RATE_LIMIT_HANDLER, + } + + if name in deprecated_items: + warnings.warn( + f"{name} has been moved to neo4j_graphrag.utils.rate_limit. " + f"Please update your imports to use 'from neo4j_graphrag.utils.rate_limit import {name}'.", + DeprecationWarning, + stacklevel=2, + ) + return deprecated_items[name] + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/neo4j_graphrag/llm/rate_limit.py b/src/neo4j_graphrag/llm/rate_limit.py index 4461c6d22..c7e706877 100644 --- a/src/neo4j_graphrag/llm/rate_limit.py +++ b/src/neo4j_graphrag/llm/rate_limit.py @@ -48,7 +48,7 @@ def __getattr__(name: str) -> Any: "convert_to_rate_limit_error": _convert_to_rate_limit_error, "DEFAULT_RATE_LIMIT_HANDLER": _DEFAULT_RATE_LIMIT_HANDLER, } - + if name in deprecated_items: warnings.warn( f"{name} has been moved to neo4j_graphrag.utils.rate_limit. " @@ -57,22 +57,11 @@ def __getattr__(name: str) -> Any: stacklevel=2, ) return deprecated_items[name] - + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -# For backward compatibility, also expose the deprecated items at module level -# This handles cases where users do: import neo4j_graphrag.llm.rate_limit; rate_limit.RateLimitHandler -RateLimitHandler = _RateLimitHandler -NoOpRateLimitHandler = _NoOpRateLimitHandler -RetryRateLimitHandler = _RetryRateLimitHandler -rate_limit_handler = _rate_limit_handler -async_rate_limit_handler = _async_rate_limit_handler -is_rate_limit_error = _is_rate_limit_error -convert_to_rate_limit_error = _convert_to_rate_limit_error -DEFAULT_RATE_LIMIT_HANDLER = _DEFAULT_RATE_LIMIT_HANDLER - -# Issue deprecation warnings for module-level access +# Issue a single deprecation warning when the module is imported warnings.warn( "The neo4j_graphrag.llm.rate_limit module has been moved to neo4j_graphrag.utils.rate_limit. " "Please update your imports to use 'from neo4j_graphrag.utils.rate_limit import ...'.", From c2857cb556aa44c34967c232c1513ce646f03b46 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 30 Sep 2025 14:15:03 +0200 Subject: [PATCH 14/14] Ruff --- src/neo4j_graphrag/llm/__init__.py | 6 +++--- src/neo4j_graphrag/llm/rate_limit.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/neo4j_graphrag/llm/__init__.py b/src/neo4j_graphrag/llm/__init__.py index d34984b02..f6d63376f 100644 --- a/src/neo4j_graphrag/llm/__init__.py +++ b/src/neo4j_graphrag/llm/__init__.py @@ -50,7 +50,7 @@ def __getattr__(name: str) -> Any: convert_to_rate_limit_error, DEFAULT_RATE_LIMIT_HANDLER, ) - + deprecated_items = { "RateLimitHandler": RateLimitHandler, "NoOpRateLimitHandler": NoOpRateLimitHandler, @@ -61,7 +61,7 @@ def __getattr__(name: str) -> Any: "convert_to_rate_limit_error": convert_to_rate_limit_error, "DEFAULT_RATE_LIMIT_HANDLER": DEFAULT_RATE_LIMIT_HANDLER, } - + if name in deprecated_items: warnings.warn( f"{name} has been moved to neo4j_graphrag.utils.rate_limit. " @@ -70,5 +70,5 @@ def __getattr__(name: str) -> Any: stacklevel=2, ) return deprecated_items[name] - + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/neo4j_graphrag/llm/rate_limit.py b/src/neo4j_graphrag/llm/rate_limit.py index c7e706877..d2a82c595 100644 --- a/src/neo4j_graphrag/llm/rate_limit.py +++ b/src/neo4j_graphrag/llm/rate_limit.py @@ -48,7 +48,7 @@ def __getattr__(name: str) -> Any: "convert_to_rate_limit_error": _convert_to_rate_limit_error, "DEFAULT_RATE_LIMIT_HANDLER": _DEFAULT_RATE_LIMIT_HANDLER, } - + if name in deprecated_items: warnings.warn( f"{name} has been moved to neo4j_graphrag.utils.rate_limit. " @@ -57,7 +57,7 @@ def __getattr__(name: str) -> Any: stacklevel=2, ) return deprecated_items[name] - + raise AttributeError(f"module {__name__!r} has no attribute {name!r}")