diff --git a/CHANGELOG.md b/CHANGELOG.md index 35a2154fd..1c0b6c9b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## Next +### Added + +- Added automatic rate limiting with retry logic and exponential backoff for all Embedding providers using tenacity. The `RateLimitHandler` interface allows for custom rate limiting strategies, including the ability to disable rate limiting entirely. + ## 1.10.0 ### Added diff --git a/docs/source/api.rst b/docs/source/api.rst index 4066348b8..f891b4dec 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -359,19 +359,19 @@ Rate Limiting RateLimitHandler ---------------- -.. autoclass:: neo4j_graphrag.llm.rate_limit.RateLimitHandler +.. autoclass:: neo4j_graphrag.utils.rate_limit.RateLimitHandler :members: RetryRateLimitHandler --------------------- -.. autoclass:: neo4j_graphrag.llm.rate_limit.RetryRateLimitHandler +.. autoclass:: neo4j_graphrag.utils.rate_limit.RetryRateLimitHandler :members: NoOpRateLimitHandler -------------------- -.. autoclass:: neo4j_graphrag.llm.rate_limit.NoOpRateLimitHandler +.. autoclass:: neo4j_graphrag.utils.rate_limit.NoOpRateLimitHandler :members: diff --git a/docs/source/user_guide_rag.rst b/docs/source/user_guide_rag.rst index 0a6afed34..ce0e4fc83 100644 --- a/docs/source/user_guide_rag.rst +++ b/docs/source/user_guide_rag.rst @@ -327,7 +327,7 @@ Rate limiting is enabled by default for all LLM instances with the following con .. code:: python from neo4j_graphrag.llm import OpenAILLM - from neo4j_graphrag.llm.rate_limit import RetryRateLimitHandler + from neo4j_graphrag.utils.rate_limit import RetryRateLimitHandler # Customize rate limiting parameters llm = OpenAILLM( @@ -348,7 +348,7 @@ You can customize the rate limiting behavior by creating your own rate limit han .. code:: python from neo4j_graphrag.llm import AnthropicLLM - from neo4j_graphrag.llm.rate_limit import RateLimitHandler + from neo4j_graphrag.utils.rate_limit import RateLimitHandler class CustomRateLimitHandler(RateLimitHandler): """Implement your custom rate limiting strategy.""" @@ -528,6 +528,37 @@ The `OpenAIEmbeddings` was illustrated previously. Here is how to use the `Sente If another embedder is desired, a custom embedder can be created, using the `Embedder` interface. +Embedder Rate Limiting +---------------------- + +All embedder implementations include automatic rate limiting that uses retry logic with exponential backoff by default, similar to LLM implementations. This feature helps handle API rate limits from embedding providers gracefully. + +.. code:: python + + from neo4j_graphrag.embeddings import OpenAIEmbeddings + from neo4j_graphrag.utils.rate_limit import RetryRateLimitHandler, NoOpRateLimitHandler + + # Default rate limiting (automatically enabled) + embedder = OpenAIEmbeddings(model="text-embedding-3-large") + + # Custom rate limiting configuration + embedder = OpenAIEmbeddings( + model="text-embedding-3-large", + rate_limit_handler=RetryRateLimitHandler( + max_attempts=5, + min_wait=2.0, + max_wait=120.0 + ) + ) + + # Disable rate limiting + embedder = OpenAIEmbeddings( + model="text-embedding-3-large", + rate_limit_handler=NoOpRateLimitHandler() + ) + +The rate limiting configuration works the same way as for LLMs. See the :ref:`Rate Limit Handling ` section above for more details on customization options. + Other Vector Retriever Configuration ---------------------------------------- diff --git a/examples/customize/embeddings/custom_embeddings.py b/examples/customize/embeddings/custom_embeddings.py index 5b15eb0f7..e77127359 100644 --- a/examples/customize/embeddings/custom_embeddings.py +++ b/examples/customize/embeddings/custom_embeddings.py @@ -8,7 +8,7 @@ class CustomEmbeddings(Embedder): def __init__(self, dimension: int = 10, **kwargs: Any): self.dimension = dimension - def embed_query(self, input: str) -> list[float]: + def _embed_query(self, input: str) -> list[float]: return [random.random() for _ in range(self.dimension)] diff --git a/examples/customize/llms/custom_llm.py b/examples/customize/llms/custom_llm.py index 0eecfd878..86b3cb993 100644 --- a/examples/customize/llms/custom_llm.py +++ b/examples/customize/llms/custom_llm.py @@ -3,7 +3,7 @@ from typing import Any, Awaitable, Callable, List, Optional, TypeVar, Union from neo4j_graphrag.llm import LLMInterface, LLMResponse -from neo4j_graphrag.llm.rate_limit import ( +from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, # rate_limit_handler, # async_rate_limit_handler, diff --git a/examples/customize/retrievers/hybrid_retrievers/hybrid_cypher_search.py b/examples/customize/retrievers/hybrid_retrievers/hybrid_cypher_search.py index e1d59e379..6268e82fd 100644 --- a/examples/customize/retrievers/hybrid_retrievers/hybrid_cypher_search.py +++ b/examples/customize/retrievers/hybrid_retrievers/hybrid_cypher_search.py @@ -20,7 +20,7 @@ # Create Embedder object class CustomEmbedder(Embedder): - def embed_query(self, text: str) -> list[float]: + def _embed_query(self, text: str) -> list[float]: return [random() for _ in range(DIMENSION)] diff --git a/examples/customize/retrievers/hybrid_retrievers/hybrid_search.py b/examples/customize/retrievers/hybrid_retrievers/hybrid_search.py index 69940596a..b9b9dd792 100644 --- a/examples/customize/retrievers/hybrid_retrievers/hybrid_search.py +++ b/examples/customize/retrievers/hybrid_retrievers/hybrid_search.py @@ -20,7 +20,7 @@ # Create Embedder object class CustomEmbedder(Embedder): - def embed_query(self, text: str) -> list[float]: + def _embed_query(self, text: str) -> list[float]: return [random() for _ in range(DIMENSION)] diff --git a/src/neo4j_graphrag/embeddings/base.py b/src/neo4j_graphrag/embeddings/base.py index 73d6bc9ac..02e5b7a51 100644 --- a/src/neo4j_graphrag/embeddings/base.py +++ b/src/neo4j_graphrag/embeddings/base.py @@ -15,15 +15,31 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import Optional + +from neo4j_graphrag.utils.rate_limit import ( + DEFAULT_RATE_LIMIT_HANDLER, + RateLimitHandler, + rate_limit_handler, +) class Embedder(ABC): """ Interface for embedding models. An embedder passed into a retriever must implement this interface. + + Args: + rate_limit_handler (Optional[RateLimitHandler]): Handler for rate limiting. Defaults to retry with exponential backoff. """ - @abstractmethod + def __init__(self, rate_limit_handler: Optional[RateLimitHandler] = None): + if rate_limit_handler is not None: + self._rate_limit_handler = rate_limit_handler + else: + self._rate_limit_handler = DEFAULT_RATE_LIMIT_HANDLER + + @rate_limit_handler def embed_query(self, text: str) -> list[float]: """Embed query text. @@ -33,3 +49,15 @@ def embed_query(self, text: str) -> list[float]: Returns: list[float]: A vector embedding. """ + return self._embed_query(text) + + @abstractmethod + def _embed_query(self, text: str) -> list[float]: + """Embed query text. + + Args: + text (str): Text to convert to vector embedding + + Returns: + list[float]: A vector embedding. + """ diff --git a/src/neo4j_graphrag/embeddings/cohere.py b/src/neo4j_graphrag/embeddings/cohere.py index 63906a5e0..6d89fcca0 100644 --- a/src/neo4j_graphrag/embeddings/cohere.py +++ b/src/neo4j_graphrag/embeddings/cohere.py @@ -14,9 +14,11 @@ # limitations under the License. from __future__ import annotations -from typing import Any +from typing import Any, Optional from neo4j_graphrag.embeddings.base import Embedder +from neo4j_graphrag.exceptions import EmbeddingsGenerationError +from neo4j_graphrag.utils.rate_limit import RateLimitHandler try: import cohere @@ -25,19 +27,30 @@ class CohereEmbeddings(Embedder): - def __init__(self, model: str = "", **kwargs: Any) -> None: + def __init__( + self, + model: str = "", + rate_limit_handler: Optional[RateLimitHandler] = None, + **kwargs: Any, + ) -> None: if cohere is None: raise ImportError( """Could not import cohere python client. Please install it with `pip install "neo4j-graphrag[cohere]"`.""" ) + super().__init__(rate_limit_handler) self.model = model self.client = cohere.Client(**kwargs) - def embed_query(self, text: str, **kwargs: Any) -> list[float]: - response = self.client.embed( - texts=[text], - model=self.model, - **kwargs, - ) - return response.embeddings[0] # type: ignore + def _embed_query(self, text: str, **kwargs: Any) -> list[float]: + try: + response = self.client.embed( + texts=[text], + model=self.model, + **kwargs, + ) + return response.embeddings[0] # type: ignore + except Exception as e: + raise EmbeddingsGenerationError( + f"Failed to generate embedding with Cohere: {e}" + ) from e diff --git a/src/neo4j_graphrag/embeddings/mistral.py b/src/neo4j_graphrag/embeddings/mistral.py index 099430193..2b1c3d284 100644 --- a/src/neo4j_graphrag/embeddings/mistral.py +++ b/src/neo4j_graphrag/embeddings/mistral.py @@ -16,10 +16,11 @@ from __future__ import annotations import os -from typing import Any +from typing import Any, Optional from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError +from neo4j_graphrag.utils.rate_limit import RateLimitHandler try: from mistralai import Mistral @@ -36,19 +37,25 @@ class MistralAIEmbeddings(Embedder): model (str): The name of the Mistral AI text embedding model to use. Defaults to "mistral-embed". """ - def __init__(self, model: str = "mistral-embed", **kwargs: Any) -> None: + def __init__( + self, + model: str = "mistral-embed", + rate_limit_handler: Optional[RateLimitHandler] = None, + **kwargs: Any, + ) -> None: if Mistral is None: raise ImportError( """Could not import mistralai. Please install it with `pip install "neo4j-graphrag[mistralai]"`.""" ) + super().__init__(rate_limit_handler) api_key = kwargs.pop("api_key", None) if api_key is None: api_key = os.getenv("MISTRAL_API_KEY", "") self.model = model self.mistral_client = Mistral(api_key=api_key, **kwargs) - def embed_query(self, text: str, **kwargs: Any) -> list[float]: + def _embed_query(self, text: str, **kwargs: Any) -> list[float]: """ Generate embeddings for a given query using a Mistral AI text embedding model. @@ -56,9 +63,15 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]: text (str): The text to generate an embedding for. **kwargs (Any): Additional keyword arguments to pass to the Mistral AI client. """ - embeddings_batch_response = self.mistral_client.embeddings.create( - model=self.model, inputs=[text], **kwargs - ) + try: + embeddings_batch_response = self.mistral_client.embeddings.create( + model=self.model, inputs=[text], **kwargs + ) + except Exception as e: + raise EmbeddingsGenerationError( + f"Failed to generate embedding with MistralAI: {e}" + ) from e + if embeddings_batch_response is None or not embeddings_batch_response.data: raise EmbeddingsGenerationError("Failed to retrieve embeddings.") diff --git a/src/neo4j_graphrag/embeddings/ollama.py b/src/neo4j_graphrag/embeddings/ollama.py index 78775ba60..e70fe96be 100644 --- a/src/neo4j_graphrag/embeddings/ollama.py +++ b/src/neo4j_graphrag/embeddings/ollama.py @@ -15,10 +15,11 @@ from __future__ import annotations -from typing import Any +from typing import Any, Optional from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.exceptions import EmbeddingsGenerationError +from neo4j_graphrag.utils.rate_limit import RateLimitHandler class OllamaEmbeddings(Embedder): @@ -30,7 +31,12 @@ class OllamaEmbeddings(Embedder): model (str): The name of the Mistral AI text embedding model to use. Defaults to "mistral-embed". """ - def __init__(self, model: str, **kwargs: Any) -> None: + def __init__( + self, + model: str, + rate_limit_handler: Optional[RateLimitHandler] = None, + **kwargs: Any, + ) -> None: try: import ollama except ImportError: @@ -38,10 +44,11 @@ def __init__(self, model: str, **kwargs: Any) -> None: """Could not import ollama python client. Please install it with `pip install "neo4j_graphrag[ollama]"`.""" ) + super().__init__(rate_limit_handler) self.model = model self.client = ollama.Client(**kwargs) - def embed_query(self, text: str, **kwargs: Any) -> list[float]: + def _embed_query(self, text: str, **kwargs: Any) -> list[float]: """ Generate embeddings for a given query using an Ollama text embedding model. diff --git a/src/neo4j_graphrag/embeddings/openai.py b/src/neo4j_graphrag/embeddings/openai.py index 4a4d60387..9bcf5df70 100644 --- a/src/neo4j_graphrag/embeddings/openai.py +++ b/src/neo4j_graphrag/embeddings/openai.py @@ -16,9 +16,11 @@ from __future__ import annotations import abc -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional from neo4j_graphrag.embeddings.base import Embedder +from neo4j_graphrag.exceptions import EmbeddingsGenerationError +from neo4j_graphrag.utils.rate_limit import RateLimitHandler if TYPE_CHECKING: import openai @@ -31,7 +33,12 @@ class BaseOpenAIEmbeddings(Embedder, abc.ABC): client: openai.OpenAI - def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None: + def __init__( + self, + model: str = "text-embedding-ada-002", + rate_limit_handler: Optional[RateLimitHandler] = None, + **kwargs: Any, + ) -> None: try: import openai except ImportError: @@ -39,6 +46,7 @@ def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None """Could not import openai python client. Please install it with `pip install "neo4j-graphrag[openai]"`.""" ) + super().__init__(rate_limit_handler) self.openai = openai self.model = model self.client = self._initialize_client(**kwargs) @@ -51,7 +59,7 @@ def _initialize_client(self, **kwargs: Any) -> Any: """ pass - def embed_query(self, text: str, **kwargs: Any) -> list[float]: + def _embed_query(self, text: str, **kwargs: Any) -> list[float]: """ Generate embeddings for a given query using an OpenAI text embedding model. @@ -59,9 +67,16 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]: text (str): The text to generate an embedding for. **kwargs (Any): Additional arguments to pass to the OpenAI embedding generation function. """ - response = self.client.embeddings.create(input=text, model=self.model, **kwargs) - embedding: list[float] = response.data[0].embedding - return embedding + try: + response = self.client.embeddings.create( + input=text, model=self.model, **kwargs + ) + embedding: list[float] = response.data[0].embedding + return embedding + except Exception as e: + raise EmbeddingsGenerationError( + f"Failed to generate embedding with OpenAI: {e}" + ) from e class OpenAIEmbeddings(BaseOpenAIEmbeddings): diff --git a/src/neo4j_graphrag/embeddings/sentence_transformers.py b/src/neo4j_graphrag/embeddings/sentence_transformers.py index f30ed5b7f..8dca9c4f6 100644 --- a/src/neo4j_graphrag/embeddings/sentence_transformers.py +++ b/src/neo4j_graphrag/embeddings/sentence_transformers.py @@ -13,14 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional from neo4j_graphrag.embeddings.base import Embedder +from neo4j_graphrag.exceptions import EmbeddingsGenerationError +from neo4j_graphrag.utils.rate_limit import RateLimitHandler class SentenceTransformerEmbeddings(Embedder): def __init__( - self, model: str = "all-MiniLM-L6-v2", *args: Any, **kwargs: Any + self, + model: str = "all-MiniLM-L6-v2", + rate_limit_handler: Optional[RateLimitHandler] = None, + *args: Any, + **kwargs: Any, ) -> None: try: import numpy as np @@ -31,17 +37,26 @@ def __init__( """Could not import sentence_transformers python package. Please install it with `pip install "neo4j-graphrag[sentence-transformers]"`.""" ) + super().__init__(rate_limit_handler) self.torch = torch self.np = np self.model = sentence_transformers.SentenceTransformer(model, *args, **kwargs) - def embed_query(self, text: str) -> Any: - result = self.model.encode([text]) - if isinstance(result, self.torch.Tensor) or isinstance(result, self.np.ndarray): - return result.flatten().tolist() - elif isinstance(result, list) and all( - isinstance(x, self.torch.Tensor) for x in result - ): - return [item for tensor in result for item in tensor.flatten().tolist()] - else: - raise ValueError("Unexpected return type from model encoding") + def _embed_query(self, text: str) -> Any: + try: + result = self.model.encode([text]) + + if isinstance(result, self.torch.Tensor) or isinstance( + result, self.np.ndarray + ): + return result.flatten().tolist() + elif isinstance(result, list) and all( + isinstance(x, self.torch.Tensor) for x in result + ): + return [item for tensor in result for item in tensor.flatten().tolist()] + else: + raise ValueError("Unexpected return type from model encoding") + except Exception as e: + raise EmbeddingsGenerationError( + f"Failed to generate embedding with SentenceTransformer: {e}" + ) from e diff --git a/src/neo4j_graphrag/embeddings/vertexai.py b/src/neo4j_graphrag/embeddings/vertexai.py index cfed3868a..e1792816f 100644 --- a/src/neo4j_graphrag/embeddings/vertexai.py +++ b/src/neo4j_graphrag/embeddings/vertexai.py @@ -14,9 +14,11 @@ # limitations under the License. from __future__ import annotations -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional from neo4j_graphrag.embeddings.base import Embedder +from neo4j_graphrag.exceptions import EmbeddingsGenerationError +from neo4j_graphrag.utils.rate_limit import RateLimitHandler try: from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel @@ -37,15 +39,20 @@ class VertexAIEmbeddings(Embedder): model (str): The name of the Vertex AI text embedding model to use. Defaults to "text-embedding-004". """ - def __init__(self, model: str = "text-embedding-004") -> None: + def __init__( + self, + model: str = "text-embedding-004", + rate_limit_handler: Optional[RateLimitHandler] = None, + ) -> None: if TextEmbeddingModel is None: raise ImportError( """Could not import Vertex AI Python client. Please install it with `pip install "neo4j-graphrag[google]"`.""" ) + super().__init__(rate_limit_handler) self.model = TextEmbeddingModel.from_pretrained(model) - def embed_query( + def _embed_query( self, text: str, task_type: str = "RETRIEVAL_QUERY", **kwargs: Any ) -> list[float]: """ @@ -56,7 +63,14 @@ def embed_query( task_type (str): The type of the text embedding task. Defaults to "RETRIEVAL_QUERY". See https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#tasktype for a full list. **kwargs (Any): Additional keyword arguments to pass to the Vertex AI client's get_embeddings method. """ - # type annotation needed for mypy - inputs: list[str | TextEmbeddingInput] = [TextEmbeddingInput(text, task_type)] - embeddings = self.model.get_embeddings(inputs, **kwargs) - return embeddings[0].values + try: + # type annotation needed for mypy + inputs: list[str | TextEmbeddingInput] = [ + TextEmbeddingInput(text, task_type) + ] + embeddings = self.model.get_embeddings(inputs, **kwargs) + return list(embeddings[0].values) + except Exception as e: + raise EmbeddingsGenerationError( + f"Failed to generate embedding with VertexAI: {e}" + ) from e diff --git a/src/neo4j_graphrag/llm/__init__.py b/src/neo4j_graphrag/llm/__init__.py index 3c4f65d9a..f6d63376f 100644 --- a/src/neo4j_graphrag/llm/__init__.py +++ b/src/neo4j_graphrag/llm/__init__.py @@ -12,22 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings +from typing import Any + from .anthropic_llm import AnthropicLLM from .base import LLMInterface from .cohere_llm import CohereLLM from .mistralai_llm import MistralAILLM from .ollama_llm import OllamaLLM from .openai_llm import AzureOpenAILLM, OpenAILLM -from .rate_limit import ( - RateLimitHandler, - NoOpRateLimitHandler, - RetryRateLimitHandler, - rate_limit_handler, - async_rate_limit_handler, -) from .types import LLMResponse from .vertexai_llm import VertexAILLM + __all__ = [ "AnthropicLLM", "CohereLLM", @@ -38,10 +35,40 @@ "VertexAILLM", "AzureOpenAILLM", "MistralAILLM", - # Rate limiting components - "RateLimitHandler", - "NoOpRateLimitHandler", - "RetryRateLimitHandler", - "rate_limit_handler", - "async_rate_limit_handler", ] + + +def __getattr__(name: str) -> Any: + """Handle deprecated imports with warnings.""" + from neo4j_graphrag.utils.rate_limit import ( + RateLimitHandler, + NoOpRateLimitHandler, + RetryRateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, + is_rate_limit_error, + convert_to_rate_limit_error, + DEFAULT_RATE_LIMIT_HANDLER, + ) + + deprecated_items = { + "RateLimitHandler": RateLimitHandler, + "NoOpRateLimitHandler": NoOpRateLimitHandler, + "RetryRateLimitHandler": RetryRateLimitHandler, + "rate_limit_handler": rate_limit_handler, + "async_rate_limit_handler": async_rate_limit_handler, + "is_rate_limit_error": is_rate_limit_error, + "convert_to_rate_limit_error": convert_to_rate_limit_error, + "DEFAULT_RATE_LIMIT_HANDLER": DEFAULT_RATE_LIMIT_HANDLER, + } + + if name in deprecated_items: + warnings.warn( + f"{name} has been moved to neo4j_graphrag.utils.rate_limit. " + f"Please update your imports to use 'from neo4j_graphrag.utils.rate_limit import {name}'.", + DeprecationWarning, + stacklevel=2, + ) + return deprecated_items[name] + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py index 6bafef85b..21560d3f2 100644 --- a/src/neo4j_graphrag/llm/anthropic_llm.py +++ b/src/neo4j_graphrag/llm/anthropic_llm.py @@ -19,7 +19,7 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.llm.rate_limit import ( +from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, rate_limit_handler, async_rate_limit_handler, diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index cca710bc9..ff7af1c70 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -21,13 +21,13 @@ from neo4j_graphrag.types import LLMMessage from .types import LLMResponse, ToolCallResponse -from .rate_limit import ( +from neo4j_graphrag.utils.rate_limit import ( DEFAULT_RATE_LIMIT_HANDLER, ) from neo4j_graphrag.tool import Tool -from .rate_limit import RateLimitHandler +from neo4j_graphrag.utils.rate_limit import RateLimitHandler class LLMInterface(ABC): diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py index 7c3905500..2e3ca0cea 100644 --- a/src/neo4j_graphrag/llm/cohere_llm.py +++ b/src/neo4j_graphrag/llm/cohere_llm.py @@ -20,7 +20,7 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.llm.rate_limit import ( +from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, rate_limit_handler, async_rate_limit_handler, diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py index ae2a6312f..3fa8663ae 100644 --- a/src/neo4j_graphrag/llm/mistralai_llm.py +++ b/src/neo4j_graphrag/llm/mistralai_llm.py @@ -21,7 +21,7 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.llm.rate_limit import ( +from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, rate_limit_handler, async_rate_limit_handler, diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py index 214640625..94541e033 100644 --- a/src/neo4j_graphrag/llm/ollama_llm.py +++ b/src/neo4j_graphrag/llm/ollama_llm.py @@ -24,7 +24,11 @@ from neo4j_graphrag.types import LLMMessage from .base import LLMInterface -from .rate_limit import RateLimitHandler, rate_limit_handler, async_rate_limit_handler +from neo4j_graphrag.utils.rate_limit import ( + RateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, +) from .types import ( BaseMessage, LLMResponse, diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index d74c83dcf..afdf0234d 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -35,7 +35,11 @@ from ..exceptions import LLMGenerationError from .base import LLMInterface -from .rate_limit import RateLimitHandler, rate_limit_handler, async_rate_limit_handler +from neo4j_graphrag.utils.rate_limit import ( + RateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, +) from .types import ( BaseMessage, LLMResponse, diff --git a/src/neo4j_graphrag/llm/rate_limit.py b/src/neo4j_graphrag/llm/rate_limit.py index 098597f78..d2a82c595 100644 --- a/src/neo4j_graphrag/llm/rate_limit.py +++ b/src/neo4j_graphrag/llm/rate_limit.py @@ -12,243 +12,59 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations -import functools -import logging -from abc import ABC, abstractmethod -from typing import Any, Awaitable, Callable, TypeVar - -from neo4j_graphrag.exceptions import RateLimitError - -from tenacity import ( - retry, - stop_after_attempt, - wait_exponential, - wait_random_exponential, - retry_if_exception_type, - before_sleep_log, +""" +Deprecated module: Rate limiting functionality has been moved to neo4j_graphrag.utils.rate_limit. + +This module provides backward compatibility with deprecation warnings. +All new code should import from neo4j_graphrag.utils.rate_limit instead. +""" + +import warnings +from typing import Any + +# Import the actual implementations from the new location +from neo4j_graphrag.utils.rate_limit import ( + RateLimitHandler as _RateLimitHandler, + NoOpRateLimitHandler as _NoOpRateLimitHandler, + RetryRateLimitHandler as _RetryRateLimitHandler, + rate_limit_handler as _rate_limit_handler, + async_rate_limit_handler as _async_rate_limit_handler, + is_rate_limit_error as _is_rate_limit_error, + convert_to_rate_limit_error as _convert_to_rate_limit_error, + DEFAULT_RATE_LIMIT_HANDLER as _DEFAULT_RATE_LIMIT_HANDLER, ) -logger = logging.getLogger(__name__) - -F = TypeVar("F", bound=Callable[..., Any]) -AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]]) - - -class RateLimitHandler(ABC): - """Abstract base class for rate limit handling strategies.""" - - @abstractmethod - def handle_sync(self, func: F) -> F: - """Apply rate limit handling to a synchronous function. - - Args: - func: The function to wrap with rate limit handling. - - Returns: - The wrapped function. - """ - pass - - @abstractmethod - def handle_async(self, func: AF) -> AF: - """Apply rate limit handling to an asynchronous function. - - Args: - func: The async function to wrap with rate limit handling. - - Returns: - The wrapped async function. - """ - pass - - -class NoOpRateLimitHandler(RateLimitHandler): - """A no-op rate limit handler that does not apply any rate limiting.""" - - def handle_sync(self, func: F) -> F: - """Return the function unchanged.""" - return func - - def handle_async(self, func: AF) -> AF: - """Return the async function unchanged.""" - return func - - -class RetryRateLimitHandler(RateLimitHandler): - """Rate limit handler using exponential backoff retry strategy. - - This handler uses tenacity for retry logic with exponential backoff. - - Args: - max_attempts: Maximum number of retry attempts. Defaults to 3. - min_wait: Minimum wait time between retries in seconds. Defaults to 1. - max_wait: Maximum wait time between retries in seconds. Defaults to 60. - multiplier: Exponential backoff multiplier. Defaults to 2. - jitter: Whether to add random jitter to retry delays to prevent thundering herd. Defaults to True. - """ - - def __init__( - self, - max_attempts: int = 3, - min_wait: float = 1.0, - max_wait: float = 60.0, - multiplier: float = 2.0, - jitter: bool = True, - ): - self.max_attempts = max_attempts - self.min_wait = min_wait - self.max_wait = max_wait - self.multiplier = multiplier - self.jitter = jitter - - def _get_wait_strategy(self) -> Any: - """Get the appropriate wait strategy based on jitter setting. - - Returns: - The configured wait strategy for tenacity retry. - """ - if self.jitter: - # Use built-in random exponential backoff with jitter - return wait_random_exponential( - multiplier=self.multiplier, - min=self.min_wait, - max=self.max_wait, - ) - else: - # Use standard exponential backoff without jitter - return wait_exponential( - multiplier=self.multiplier, - min=self.min_wait, - max=self.max_wait, - ) - - def handle_sync(self, func: F) -> F: - """Apply retry logic to a synchronous function.""" - decorator = retry( - retry=retry_if_exception_type(RateLimitError), - stop=stop_after_attempt(self.max_attempts), - wait=self._get_wait_strategy(), - before_sleep=before_sleep_log(logger, logging.WARNING), - ) - return decorator(func) - - def handle_async(self, func: AF) -> AF: - """Apply retry logic to an asynchronous function.""" - decorator = retry( - retry=retry_if_exception_type(RateLimitError), - stop=stop_after_attempt(self.max_attempts), - wait=self._get_wait_strategy(), - before_sleep=before_sleep_log(logger, logging.WARNING), - ) - return decorator(func) - - -def is_rate_limit_error(exception: Exception) -> bool: - """Check if an exception is a rate limit error from any LLM provider. - - Args: - exception: The exception to check. - - Returns: - True if the exception indicates a rate limit error, False otherwise. - """ - error_type = type(exception).__name__.lower() - exception_str = str(exception).lower() - - # For LLMGenerationError (which wraps all provider errors), check provider-specific patterns - if error_type == "llmgenerationerror": - # Check for various rate limit patterns from different providers - rate_limit_patterns = [ - "error code: 429", # Azure OpenAI - "too many requests", # Anthropic, Cohere, MistralAI - "resource exhausted", # VertexAI - "rate limit", # Generic rate limit messages - "429", # Generic rate limit messages - ] - - return any(pattern in exception_str for pattern in rate_limit_patterns) - - return False - - -def convert_to_rate_limit_error(exception: Exception) -> RateLimitError: - """Convert a provider-specific rate limit exception to RateLimitError. - - Args: - exception: The original exception from the LLM provider. - - Returns: - A RateLimitError with the original exception message. - """ - return RateLimitError(f"Rate limit exceeded: {exception}") - - -def rate_limit_handler(func: F) -> F: - """Decorator to apply rate limit handling to synchronous methods. - - This decorator works with instance methods and uses the instance's rate limit handler. - - Args: - func: The function to wrap with rate limit handling. - - Returns: - The wrapped function. - """ - - @functools.wraps(func) - def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - # Use instance handler or default - active_handler = getattr( - self, "_rate_limit_handler", DEFAULT_RATE_LIMIT_HANDLER - ) - - def inner_func() -> Any: - try: - return func(self, *args, **kwargs) - except Exception as e: - if is_rate_limit_error(e): - raise convert_to_rate_limit_error(e) - raise - - return active_handler.handle_sync(inner_func)() - - return wrapper # type: ignore - - -def async_rate_limit_handler(func: AF) -> AF: - """Decorator to apply rate limit handling to asynchronous methods. - - This decorator works with instance methods and uses the instance's rate limit handler. - - Args: - func: The async function to wrap with rate limit handling. - - Returns: - The wrapped async function. - """ - - @functools.wraps(func) - async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - # Use instance handler or default - active_handler = getattr( - self, "_rate_limit_handler", DEFAULT_RATE_LIMIT_HANDLER +def __getattr__(name: str) -> Any: + """Handle deprecated imports with warnings.""" + deprecated_items = { + "RateLimitHandler": _RateLimitHandler, + "NoOpRateLimitHandler": _NoOpRateLimitHandler, + "RetryRateLimitHandler": _RetryRateLimitHandler, + "rate_limit_handler": _rate_limit_handler, + "async_rate_limit_handler": _async_rate_limit_handler, + "is_rate_limit_error": _is_rate_limit_error, + "convert_to_rate_limit_error": _convert_to_rate_limit_error, + "DEFAULT_RATE_LIMIT_HANDLER": _DEFAULT_RATE_LIMIT_HANDLER, + } + + if name in deprecated_items: + warnings.warn( + f"{name} has been moved to neo4j_graphrag.utils.rate_limit. " + f"Please update your imports to use 'from neo4j_graphrag.utils.rate_limit import {name}'.", + DeprecationWarning, + stacklevel=2, ) + return deprecated_items[name] - async def inner_func() -> Any: - try: - return await func(self, *args, **kwargs) - except Exception as e: - if is_rate_limit_error(e): - raise convert_to_rate_limit_error(e) - raise + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - return await active_handler.handle_async(inner_func)() - return wrapper # type: ignore - - -# Default rate limit handler instance -DEFAULT_RATE_LIMIT_HANDLER = RetryRateLimitHandler() +# Issue a single deprecation warning when the module is imported +warnings.warn( + "The neo4j_graphrag.llm.rate_limit module has been moved to neo4j_graphrag.utils.rate_limit. " + "Please update your imports to use 'from neo4j_graphrag.utils.rate_limit import ...'.", + DeprecationWarning, + stacklevel=2, +) diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index 0b4926978..b9f1e40e8 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -19,7 +19,7 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.llm.rate_limit import ( +from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, rate_limit_handler, async_rate_limit_handler, diff --git a/src/neo4j_graphrag/utils/rate_limit.py b/src/neo4j_graphrag/utils/rate_limit.py new file mode 100644 index 000000000..165de21ee --- /dev/null +++ b/src/neo4j_graphrag/utils/rate_limit.py @@ -0,0 +1,254 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import functools +import logging +from abc import ABC, abstractmethod +from typing import Any, Awaitable, Callable, TypeVar + +from neo4j_graphrag.exceptions import RateLimitError + +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + wait_random_exponential, + retry_if_exception_type, + before_sleep_log, +) + + +logger = logging.getLogger(__name__) + +F = TypeVar("F", bound=Callable[..., Any]) +AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]]) + + +class RateLimitHandler(ABC): + """Abstract base class for rate limit handling strategies.""" + + @abstractmethod + def handle_sync(self, func: F) -> F: + """Apply rate limit handling to a synchronous function. + + Args: + func: The function to wrap with rate limit handling. + + Returns: + The wrapped function. + """ + pass + + @abstractmethod + def handle_async(self, func: AF) -> AF: + """Apply rate limit handling to an asynchronous function. + + Args: + func: The async function to wrap with rate limit handling. + + Returns: + The wrapped async function. + """ + pass + + +class NoOpRateLimitHandler(RateLimitHandler): + """A no-op rate limit handler that does not apply any rate limiting.""" + + def handle_sync(self, func: F) -> F: + """Return the function unchanged.""" + return func + + def handle_async(self, func: AF) -> AF: + """Return the async function unchanged.""" + return func + + +class RetryRateLimitHandler(RateLimitHandler): + """Rate limit handler using exponential backoff retry strategy. + + This handler uses tenacity for retry logic with exponential backoff. + + Args: + max_attempts: Maximum number of retry attempts. Defaults to 3. + min_wait: Minimum wait time between retries in seconds. Defaults to 1. + max_wait: Maximum wait time between retries in seconds. Defaults to 60. + multiplier: Exponential backoff multiplier. Defaults to 2. + jitter: Whether to add random jitter to retry delays to prevent thundering herd. Defaults to True. + """ + + def __init__( + self, + max_attempts: int = 3, + min_wait: float = 1.0, + max_wait: float = 60.0, + multiplier: float = 2.0, + jitter: bool = True, + ): + self.max_attempts = max_attempts + self.min_wait = min_wait + self.max_wait = max_wait + self.multiplier = multiplier + self.jitter = jitter + + def _get_wait_strategy(self) -> Any: + """Get the appropriate wait strategy based on jitter setting. + + Returns: + The configured wait strategy for tenacity retry. + """ + if self.jitter: + # Use built-in random exponential backoff with jitter + return wait_random_exponential( + multiplier=self.multiplier, + min=self.min_wait, + max=self.max_wait, + ) + else: + # Use standard exponential backoff without jitter + return wait_exponential( + multiplier=self.multiplier, + min=self.min_wait, + max=self.max_wait, + ) + + def handle_sync(self, func: F) -> F: + """Apply retry logic to a synchronous function.""" + decorator = retry( + retry=retry_if_exception_type(RateLimitError), + stop=stop_after_attempt(self.max_attempts), + wait=self._get_wait_strategy(), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + return decorator(func) + + def handle_async(self, func: AF) -> AF: + """Apply retry logic to an asynchronous function.""" + decorator = retry( + retry=retry_if_exception_type(RateLimitError), + stop=stop_after_attempt(self.max_attempts), + wait=self._get_wait_strategy(), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + return decorator(func) + + +def is_rate_limit_error(exception: Exception) -> bool: + """Check if an exception is a rate limit error from any LLM provider or embedder. + + Args: + exception: The exception to check. + + Returns: + True if the exception indicates a rate limit error, False otherwise. + """ + error_type = type(exception).__name__.lower() + exception_str = str(exception).lower() + + # For LLMGenerationError or EmbeddingsGenerationError (which wrap all provider errors), check provider-specific patterns + if error_type in ["llmgenerationerror", "embeddingsgenerationerror"]: + # Check for various rate limit patterns from different providers + rate_limit_patterns = [ + "error code: 429", # Azure OpenAI + "too many requests", # Anthropic, Cohere, MistralAI + "resource exhausted", # VertexAI + "rate limit", # Generic rate limit messages + "429", # Generic rate limit messages + ] + + return any(pattern in exception_str for pattern in rate_limit_patterns) + + return False + + +def convert_to_rate_limit_error(exception: Exception) -> RateLimitError: + """Convert a provider-specific rate limit exception to RateLimitError. + + Args: + exception: The original exception from the LLM provider. + + Returns: + A RateLimitError with the original exception message. + """ + return RateLimitError(f"Rate limit exceeded: {exception}") + + +def rate_limit_handler(func: F) -> F: + """Decorator to apply rate limit handling to synchronous methods. + + This decorator works with instance methods and uses the instance's rate limit handler. + + Args: + func: The function to wrap with rate limit handling. + + Returns: + The wrapped function. + """ + + @functools.wraps(func) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + # Use instance handler or default + active_handler = getattr( + self, "_rate_limit_handler", DEFAULT_RATE_LIMIT_HANDLER + ) + + def inner_func() -> Any: + try: + return func(self, *args, **kwargs) + except Exception as e: + if is_rate_limit_error(e): + raise convert_to_rate_limit_error(e) + raise + + return active_handler.handle_sync(inner_func)() + + return wrapper # type: ignore + + +def async_rate_limit_handler(func: AF) -> AF: + """Decorator to apply rate limit handling to asynchronous methods. + + This decorator works with instance methods and uses the instance's rate limit handler. + + Args: + func: The async function to wrap with rate limit handling. + + Returns: + The wrapped async function. + """ + + @functools.wraps(func) + async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + # Use instance handler or default + active_handler = getattr( + self, "_rate_limit_handler", DEFAULT_RATE_LIMIT_HANDLER + ) + + async def inner_func() -> Any: + try: + return await func(self, *args, **kwargs) + except Exception as e: + if is_rate_limit_error(e): + raise convert_to_rate_limit_error(e) + raise + + return await active_handler.handle_async(inner_func)() + + return wrapper # type: ignore + + +# Default rate limit handler instance +DEFAULT_RATE_LIMIT_HANDLER = RetryRateLimitHandler() diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 9932f12e5..4ae462a12 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -58,12 +58,12 @@ def embedder() -> Embedder: class RandomEmbedder(Embedder): - def embed_query(self, text: str) -> list[float]: + def _embed_query(self, text: str) -> list[float]: return [random.random() for _ in range(1536)] class BiologyEmbedder(Embedder): - def embed_query(self, text: str) -> list[float]: + def _embed_query(self, text: str) -> list[float]: if text == "biology": return EMBEDDING_BIOLOGY raise ValueError(f"Unknown embedding text: {text}") diff --git a/tests/unit/embeddings/test_cohere_embedder.py b/tests/unit/embeddings/test_cohere_embedder.py index 244e90c52..af87f4a24 100644 --- a/tests/unit/embeddings/test_cohere_embedder.py +++ b/tests/unit/embeddings/test_cohere_embedder.py @@ -15,7 +15,9 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from tenacity import RetryError from neo4j_graphrag.embeddings.cohere import CohereEmbeddings +from neo4j_graphrag.exceptions import EmbeddingsGenerationError @patch("neo4j_graphrag.embeddings.cohere.cohere", None) @@ -32,3 +34,59 @@ def test_cohere_embedder_happy_path(mock_cohere: Mock) -> None: embedder = CohereEmbeddings() res = embedder.embed_query("my text") assert res == [1.0, 2.0] + + +@patch("neo4j_graphrag.embeddings.cohere.cohere") +def test_cohere_embedder_non_retryable_error_handling(mock_cohere: Mock) -> None: + """Test that non-retryable errors fail immediately without retries.""" + mock_embeddings = mock_cohere.Client.return_value.embed + mock_embeddings.side_effect = Exception("API Error") + embedder = CohereEmbeddings() + with pytest.raises( + EmbeddingsGenerationError, match="Failed to generate embedding with Cohere" + ): + embedder.embed_query("my text") + + # Verify the API was called only once (no retries for non-rate-limit errors) + assert mock_embeddings.call_count == 1 + + +@patch("neo4j_graphrag.embeddings.cohere.cohere") +def test_cohere_embedder_rate_limit_error_retries(mock_cohere: Mock) -> None: + """Test that rate limit errors are retried the expected number of times.""" + # Rate limit error that should trigger retries (matches "too many requests" pattern) + # Create separate exception instances for each retry attempt + mock_embeddings = mock_cohere.Client.return_value.embed + mock_embeddings.side_effect = [ + Exception("too many requests - please try again later"), + Exception("too many requests - please try again later"), + Exception("too many requests - please try again later"), + ] + embedder = CohereEmbeddings() + + # After exhausting retries, tenacity raises RetryError + with pytest.raises(RetryError): + embedder.embed_query("my text") + + # Verify the API was called 3 times (default max_attempts for RetryRateLimitHandler) + assert mock_cohere.Client.return_value.embed.call_count == 3 + + +@patch("neo4j_graphrag.embeddings.cohere.cohere") +def test_cohere_embedder_rate_limit_error_eventual_success(mock_cohere: Mock) -> None: + """Test that rate limit errors eventually succeed after retries.""" + # First two calls fail with rate limit, third succeeds + mock_embeddings = mock_cohere.Client.return_value.embed + mock_embeddings.side_effect = [ + Exception("too many requests - please try again later"), + Exception("too many requests - please try again later"), + MagicMock(embeddings=[[1.0, 2.0]]), + ] + embedder = CohereEmbeddings() + + result = embedder.embed_query("my text") + + # Verify successful result + assert result == [1.0, 2.0] + # Verify the API was called 3 times before succeeding + assert mock_embeddings.call_count == 3 diff --git a/tests/unit/embeddings/test_mistralai_embedder.py b/tests/unit/embeddings/test_mistralai_embedder.py index f962f7c37..f2fd4cbb4 100644 --- a/tests/unit/embeddings/test_mistralai_embedder.py +++ b/tests/unit/embeddings/test_mistralai_embedder.py @@ -15,7 +15,9 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from tenacity import RetryError from neo4j_graphrag.embeddings import MistralAIEmbeddings +from neo4j_graphrag.exceptions import EmbeddingsGenerationError @patch("neo4j_graphrag.embeddings.mistral.Mistral", None) @@ -72,3 +74,72 @@ def test_mistralai_embedder_api_key_from_env( mock_getenv.assert_called_with("MISTRAL_API_KEY", "") mock_mistral.assert_called_with(api_key="env_api_key") + + +@patch("neo4j_graphrag.embeddings.mistral.Mistral") +def test_mistralai_embedder_non_retryable_error_handling(mock_mistral: Mock) -> None: + """Test that non-retryable errors fail immediately without retries.""" + mock_mistral_instance = mock_mistral.return_value + mock_embeddings = mock_mistral_instance.embeddings.create + mock_embeddings.side_effect = Exception("API Error") + + embedder = MistralAIEmbeddings() + # MistralAI now wraps exceptions, so we expect EmbeddingsGenerationError + with pytest.raises( + EmbeddingsGenerationError, match="Failed to generate embedding with MistralAI" + ): + embedder.embed_query("my text") + + # Verify the API was called only once (no retries for non-rate-limit errors) + assert mock_embeddings.call_count == 1 + + +@patch("neo4j_graphrag.embeddings.mistral.Mistral") +def test_mistralai_embedder_rate_limit_error_retries(mock_mistral: Mock) -> None: + """Test that rate limit errors are retried the expected number of times.""" + mock_mistral_instance = mock_mistral.return_value + + # Rate limit error that should trigger retries (matches "too many requests" pattern) + # Create separate exception instances for each retry attempt + mock_embeddings = mock_mistral_instance.embeddings.create + mock_embeddings.side_effect = [ + Exception("too many requests - rate limit exceeded"), + Exception("too many requests - rate limit exceeded"), + Exception("too many requests - rate limit exceeded"), + ] + + embedder = MistralAIEmbeddings() + + # After exhausting retries, tenacity raises RetryError + with pytest.raises(RetryError): + embedder.embed_query("my text") + + # Verify the API was called 3 times (default max_attempts for RetryRateLimitHandler) + assert mock_embeddings.call_count == 3 + + +@patch("neo4j_graphrag.embeddings.mistral.Mistral") +def test_mistralai_embedder_rate_limit_error_eventual_success( + mock_mistral: Mock, +) -> None: + """Test that rate limit errors eventually succeed after retries.""" + mock_mistral_instance = mock_mistral.return_value + + # First two calls fail with rate limit, third succeeds + embeddings_batch_response_mock = MagicMock() + embeddings_batch_response_mock.data = [MagicMock(embedding=[1.0, 2.0])] + + mock_embeddings = mock_mistral_instance.embeddings.create + mock_embeddings.side_effect = [ + Exception("too many requests - rate limit exceeded"), + Exception("too many requests - rate limit exceeded"), + embeddings_batch_response_mock, + ] + + embedder = MistralAIEmbeddings() + result = embedder.embed_query("my text") + + # Verify successful result + assert result == [1.0, 2.0] + # Verify the API was called 3 times before succeeding + assert mock_embeddings.call_count == 3 diff --git a/tests/unit/embeddings/test_openai_embedder.py b/tests/unit/embeddings/test_openai_embedder.py index a1b940f04..cdfc3d6c5 100644 --- a/tests/unit/embeddings/test_openai_embedder.py +++ b/tests/unit/embeddings/test_openai_embedder.py @@ -16,10 +16,12 @@ import openai import pytest +from tenacity import RetryError from neo4j_graphrag.embeddings.openai import ( AzureOpenAIEmbeddings, OpenAIEmbeddings, ) +from neo4j_graphrag.exceptions import EmbeddingsGenerationError def get_mock_openai() -> MagicMock: @@ -92,3 +94,70 @@ def test_azure_openai_embedder_does_not_call_openai_client() -> None: api_key="my_key", api_version="2023-05-15", ) + + +@patch("builtins.__import__") +def test_openai_embedder_non_retryable_error_handling(mock_import: Mock) -> None: + """Test that non-retryable errors fail immediately without retries.""" + mock_openai = get_mock_openai() + mock_import.return_value = mock_openai + + # Generic API error that doesn't match rate limit patterns - should not be retried + mock_embeddings = mock_openai.OpenAI.return_value.embeddings.create + mock_embeddings.side_effect = Exception("API Error") + embedder = OpenAIEmbeddings(api_key="my key") + + with pytest.raises( + EmbeddingsGenerationError, match="Failed to generate embedding with OpenAI" + ): + embedder.embed_query("my text") + + # Verify the API was called only once (no retries for non-rate-limit errors) + assert mock_embeddings.call_count == 1 + + +@patch("builtins.__import__") +def test_openai_embedder_rate_limit_error_retries(mock_import: Mock) -> None: + """Test that rate limit errors are retried the expected number of times.""" + mock_openai = get_mock_openai() + mock_import.return_value = mock_openai + + # Rate limit error that should trigger retries (matches "429" pattern) + # Create separate exception instances for each retry attempt + mock_embeddings = mock_openai.OpenAI.return_value.embeddings.create + mock_embeddings.side_effect = [ + Exception("Error code: 429 - Too many requests"), + Exception("Error code: 429 - Too many requests"), + Exception("Error code: 429 - Too many requests"), + ] + embedder = OpenAIEmbeddings(api_key="my key") + + # After exhausting retries, tenacity raises RetryError + with pytest.raises(RetryError): + embedder.embed_query("my text") + + # Verify the API was called 3 times (default max_attempts for RetryRateLimitHandler) + assert mock_embeddings.call_count == 3 + + +@patch("builtins.__import__") +def test_openai_embedder_rate_limit_error_eventual_success(mock_import: Mock) -> None: + """Test that rate limit errors eventually succeed after retries.""" + mock_openai = get_mock_openai() + mock_import.return_value = mock_openai + + # First two calls fail with rate limit, third succeeds + mock_embeddings = mock_openai.OpenAI.return_value.embeddings.create + mock_embeddings.side_effect = [ + Exception("Error code: 429 - Too many requests"), + Exception("Error code: 429 - Too many requests"), + MagicMock(data=[MagicMock(embedding=[1.0, 2.0])]), + ] + embedder = OpenAIEmbeddings(api_key="my key") + + result = embedder.embed_query("my text") + + # Verify successful result + assert result == [1.0, 2.0] + # Verify the API was called 3 times before succeeding + assert mock_embeddings.call_count == 3 diff --git a/tests/unit/embeddings/test_sentence_transformers.py b/tests/unit/embeddings/test_sentence_transformers.py index 197095e49..f9c8f36bc 100644 --- a/tests/unit/embeddings/test_sentence_transformers.py +++ b/tests/unit/embeddings/test_sentence_transformers.py @@ -3,10 +3,12 @@ import numpy as np import pytest import torch +from tenacity import RetryError from neo4j_graphrag.embeddings.base import Embedder from neo4j_graphrag.embeddings.sentence_transformers import ( SentenceTransformerEmbeddings, ) +from neo4j_graphrag.exceptions import EmbeddingsGenerationError def get_mock_sentence_transformers() -> MagicMock: @@ -55,3 +57,70 @@ def test_embed_query(mock_import: Mock) -> None: def test_import_error(mock_import: Mock) -> None: with pytest.raises(ImportError): SentenceTransformerEmbeddings() + + +@patch("builtins.__import__") +def test_embed_query_non_retryable_error_handling(mock_import: Mock) -> None: + """Test that non-retryable errors fail immediately without retries.""" + MockSentenceTransformer = get_mock_sentence_transformers() + mock_import.return_value = MockSentenceTransformer + mock_model = MockSentenceTransformer.SentenceTransformer.return_value + mock_model.encode.side_effect = Exception("Model error") + + instance = SentenceTransformerEmbeddings() + with pytest.raises( + EmbeddingsGenerationError, + match="Failed to generate embedding with SentenceTransformer", + ): + instance.embed_query("test query") + + # Verify the model was called only once (no retries for non-rate-limit errors) + assert mock_model.encode.call_count == 1 + + +@patch("builtins.__import__") +def test_embed_query_rate_limit_error_retries(mock_import: Mock) -> None: + """Test that rate limit errors are retried the expected number of times.""" + MockSentenceTransformer = get_mock_sentence_transformers() + mock_import.return_value = MockSentenceTransformer + mock_model = MockSentenceTransformer.SentenceTransformer.return_value + + # Rate limit error that should trigger retries (matches "too many requests" pattern) + # Create separate exception instances for each retry attempt + mock_model.encode.side_effect = [ + Exception("too many requests - please wait"), + Exception("too many requests - please wait"), + Exception("too many requests - please wait"), + ] + + instance = SentenceTransformerEmbeddings() + + # After exhausting retries, tenacity raises RetryError (since retries should work) + with pytest.raises(RetryError): + instance.embed_query("test query") + + # Verify the model was called 3 times (default max_attempts for RetryRateLimitHandler) + assert mock_model.encode.call_count == 3 + + +@patch("builtins.__import__") +def test_embed_query_rate_limit_error_eventual_success(mock_import: Mock) -> None: + """Test that rate limit errors eventually succeed after retries.""" + MockSentenceTransformer = get_mock_sentence_transformers() + mock_import.return_value = MockSentenceTransformer + mock_model = MockSentenceTransformer.SentenceTransformer.return_value + + # First two calls fail with rate limit, third succeeds + mock_model.encode.side_effect = [ + Exception("too many requests - please wait"), + Exception("too many requests - please wait"), + np.array([[0.1, 0.2, 0.3]]), + ] + + instance = SentenceTransformerEmbeddings() + result = instance.embed_query("test query") + + # Verify successful result + assert result == [0.1, 0.2, 0.3] + # Verify the model was called 3 times before succeeding + assert mock_model.encode.call_count == 3 diff --git a/tests/unit/embeddings/test_vertexai_embedder.py b/tests/unit/embeddings/test_vertexai_embedder.py index 018230a67..790d7e402 100644 --- a/tests/unit/embeddings/test_vertexai_embedder.py +++ b/tests/unit/embeddings/test_vertexai_embedder.py @@ -15,7 +15,9 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from tenacity import RetryError from neo4j_graphrag.embeddings.vertexai import VertexAIEmbeddings +from neo4j_graphrag.exceptions import EmbeddingsGenerationError @patch("neo4j_graphrag.embeddings.vertexai.TextEmbeddingModel", None) @@ -33,3 +35,60 @@ def test_vertexai_embedder_happy_path(mock_vertexai: Mock) -> None: res = embedder.embed_query("my text") assert isinstance(res, list) assert res == [1.0, 2.0] + + +@patch("neo4j_graphrag.embeddings.vertexai.TextEmbeddingModel") +def test_vertexai_embedder_non_retryable_error_handling(mock_vertexai: Mock) -> None: + """Test that non-retryable errors fail immediately without retries.""" + mock_embeddings = mock_vertexai.from_pretrained.return_value.get_embeddings + mock_embeddings.side_effect = Exception("API Error") + embedder = VertexAIEmbeddings() + with pytest.raises( + EmbeddingsGenerationError, match="Failed to generate embedding with VertexAI" + ): + embedder.embed_query("my text") + + # Verify the API was called only once (no retries for non-rate-limit errors) + assert mock_embeddings.call_count == 1 + + +@patch("neo4j_graphrag.embeddings.vertexai.TextEmbeddingModel") +def test_vertexai_embedder_rate_limit_error_retries(mock_vertexai: Mock) -> None: + """Test that rate limit errors are retried the expected number of times.""" + # Rate limit error that should trigger retries (matches "resource exhausted" pattern) + mock_embeddings = mock_vertexai.from_pretrained.return_value.get_embeddings + mock_embeddings.side_effect = [ + Exception("resource exhausted - quota exceeded"), + Exception("resource exhausted - quota exceeded"), + Exception("resource exhausted - quota exceeded"), + ] + embedder = VertexAIEmbeddings() + + # After exhausting retries, tenacity raises RetryError + with pytest.raises(RetryError): + embedder.embed_query("my text") + + # Verify the API was called 3 times (default max_attempts for RetryRateLimitHandler) + assert mock_embeddings.call_count == 3 + + +@patch("neo4j_graphrag.embeddings.vertexai.TextEmbeddingModel") +def test_vertexai_embedder_rate_limit_error_eventual_success( + mock_vertexai: Mock, +) -> None: + """Test that rate limit errors eventually succeed after retries.""" + # First two calls fail with rate limit, third succeeds + mock_embeddings = mock_vertexai.from_pretrained.return_value.get_embeddings + mock_embeddings.side_effect = [ + Exception("resource exhausted - quota exceeded"), + Exception("resource exhausted - quota exceeded"), + [MagicMock(values=[1.0, 2.0])], + ] + embedder = VertexAIEmbeddings() + + result = embedder.embed_query("my text") + + # Verify successful result + assert result == [1.0, 2.0] + # Verify the API was called 3 times before succeeding + assert mock_embeddings.call_count == 3 diff --git a/tests/unit/llm/test_rate_limit.py b/tests/unit/llm/test_rate_limit.py index f1f4b133b..51bc80b14 100644 --- a/tests/unit/llm/test_rate_limit.py +++ b/tests/unit/llm/test_rate_limit.py @@ -19,7 +19,7 @@ from unittest.mock import Mock from tenacity import RetryError -from neo4j_graphrag.llm.rate_limit import ( +from neo4j_graphrag.utils.rate_limit import ( RateLimitHandler, NoOpRateLimitHandler, DEFAULT_RATE_LIMIT_HANDLER,