diff --git a/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py b/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py
index 592fddb2..1c901b50 100644
--- a/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py
+++ b/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py
@@ -48,13 +48,23 @@ def create_diskann_embedding_server(
     model_name: str = "sentence-transformers/all-mpnet-base-v2",
     embedding_mode: str = "sentence-transformers",
     distance_metric: str = "l2",
+    enable_warmup: bool = True,
 ):
     """
     Create and start a ZMQ-based embedding server for DiskANN backend.
     Uses ROUTER socket and protobuf communication as required by DiskANN C++ implementation.
+
+    Args:
+        passages_file: Path to the metadata file (.meta.json)
+        zmq_port: Port for ZMQ server
+        model_name: Name of the embedding model to use
+        embedding_mode: Embedding backend mode
+        distance_metric: Distance metric (l2, mips, cosine)
+        enable_warmup: If True, pre-load model and run warmup embedding on startup
     """
     logger.info(f"Starting DiskANN server on port {zmq_port} with model {model_name}")
     logger.info(f"Using embedding mode: {embedding_mode}")
+    logger.info(f"Warmup enabled: {enable_warmup}")

     # Add leann-core to path for unified embedding computation
     current_dir = Path(__file__).parent
@@ -72,6 +82,24 @@ def create_diskann_embedding_server(
     finally:
         sys.path.pop(0)

+    # Warmup: Pre-load the embedding model by computing a dummy embedding
+    # This ensures the model is cached and ready for fast subsequent queries
+    if enable_warmup:
+        warmup_start = time.time()
+        logger.info("Starting model warmup...")
+        try:
+            # Compute a dummy embedding to trigger model loading and caching
+            _ = compute_embeddings(
+                ["warmup query for model preloading"],
+                model_name,
+                mode=embedding_mode,
+                provider_options=PROVIDER_OPTIONS,
+            )
+            warmup_time = time.time() - warmup_start
+            logger.info(f"Model warmup completed in {warmup_time:.2f}s")
+        except Exception as e:
+            logger.warning(f"Model warmup failed (non-fatal): {e}")
+
     # Check port availability
     import socket

@@ -479,9 +507,23 @@ def signal_handler(sig, frame):
         choices=["l2", "mips", "cosine"],
         help="Distance metric for similarity computation",
     )
+    parser.add_argument(
+        "--enable-warmup",
+        action="store_true",
+        default=True,
+        help="Pre-load embedding model on startup for faster first query (default: True)",
+    )
+    parser.add_argument(
+        "--no-warmup",
+        action="store_true",
+        help="Disable warmup (lazy model loading)",
+    )

     args = parser.parse_args()

+    # Determine warmup setting (--no-warmup takes precedence)
+    enable_warmup = not args.no_warmup
+
     # Create and start the DiskANN embedding server
     create_diskann_embedding_server(
         passages_file=args.passages_file,
@@ -489,4 +531,5 @@ def signal_handler(sig, frame):
         model_name=args.model_name,
         embedding_mode=args.embedding_mode,
         distance_metric=args.distance_metric,
+        enable_warmup=enable_warmup,
     )
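For context, the warmup block above delegates to leann-core's compute_embeddings, which is not part of this diff. A minimal standalone sketch of the same preloading idea, assuming the sentence-transformers package that the default model name points to (this is an illustration, not leann-core's actual implementation):

    # Standalone sketch of the warmup pattern: the first encode() pays the
    # model download/load cost, so paying it at startup keeps it off the
    # first real query.
    import time

    from sentence_transformers import SentenceTransformer

    start = time.time()
    model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
    _ = model.encode(["warmup query for model preloading"])
    print(f"Model warmup completed in {time.time() - start:.2f}s")
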
diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
index 882acbf7..0a2925d3 100644
--- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
+++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
@@ -61,13 +61,23 @@ def create_hnsw_embedding_server(
     model_name: str = "sentence-transformers/all-mpnet-base-v2",
     distance_metric: str = "mips",
     embedding_mode: str = "sentence-transformers",
+    enable_warmup: bool = True,
 ):
     """
     Create and start a ZMQ-based embedding server for HNSW backend.
     Simplified version using unified embedding computation module.
+
+    Args:
+        passages_file: Path to the metadata file (.meta.json)
+        zmq_port: Port for ZMQ server
+        model_name: Name of the embedding model to use
+        distance_metric: Distance metric (mips, l2, cosine)
+        embedding_mode: Embedding backend mode
+        enable_warmup: If True, pre-load model and run warmup embedding on startup
     """
     logger.info(f"Starting HNSW server on port {zmq_port} with model {model_name}")
     logger.info(f"Using embedding mode: {embedding_mode}")
+    logger.info(f"Warmup enabled: {enable_warmup}")

     # Add leann-core to path for unified embedding computation
     current_dir = Path(__file__).parent
@@ -85,6 +95,24 @@ def create_hnsw_embedding_server(
     finally:
         sys.path.pop(0)

+    # Warmup: Pre-load the embedding model by computing a dummy embedding
+    # This ensures the model is cached and ready for fast subsequent queries
+    if enable_warmup:
+        warmup_start = time.time()
+        logger.info("Starting model warmup...")
+        try:
+            # Compute a dummy embedding to trigger model loading and caching
+            _ = compute_embeddings(
+                ["warmup query for model preloading"],
+                model_name,
+                mode=embedding_mode,
+                provider_options=PROVIDER_OPTIONS,
+            )
+            warmup_time = time.time() - warmup_start
+            logger.info(f"Model warmup completed in {warmup_time:.2f}s")
+        except Exception as e:
+            logger.warning(f"Model warmup failed (non-fatal): {e}")
+
     # Check port availability
     import socket

@@ -481,9 +509,23 @@ def signal_handler(sig, frame):
         choices=["sentence-transformers", "openai", "mlx", "ollama"],
         help="Embedding backend mode",
     )
+    parser.add_argument(
+        "--enable-warmup",
+        action="store_true",
+        default=True,
+        help="Pre-load embedding model on startup for faster first query (default: True)",
+    )
+    parser.add_argument(
+        "--no-warmup",
+        action="store_true",
+        help="Disable warmup (lazy model loading)",
+    )

     args = parser.parse_args()

+    # Determine warmup setting (--no-warmup takes precedence)
+    enable_warmup = not args.no_warmup
+
     # Create and start the HNSW embedding server
     create_hnsw_embedding_server(
         passages_file=args.passages_file,
@@ -491,4 +533,5 @@ def signal_handler(sig, frame):
         model_name=args.model_name,
         distance_metric=args.distance_metric,
         embedding_mode=args.embedding_mode,
+        enable_warmup=enable_warmup,
     )
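The paired CLI flags deserve a note: because --enable-warmup is store_true with default=True, passing it never changes anything; only --no-warmup is consulted (enable_warmup = not args.no_warmup). A self-contained sketch of the resulting behavior:

    # Self-contained sketch of the flag pair used by both servers.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--enable-warmup", action="store_true", default=True)
    parser.add_argument("--no-warmup", action="store_true")

    args = parser.parse_args([])  # default: warmup stays enabled
    assert not args.no_warmup

    args = parser.parse_args(["--no-warmup"])  # explicit opt-out
    enable_warmup = not args.no_warmup  # mirrors the servers' logic
    assert enable_warmup is False
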
diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py
index d64d4335..7e672673 100644
--- a/packages/leann-core/src/leann/api.py
+++ b/packages/leann-core/src/leann/api.py
@@ -865,6 +865,15 @@ def update_index(self, index_path: str):

 class LeannSearcher:
     def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
+        """Initialize a LeannSearcher for searching an existing index.
+
+        Args:
+            index_path: Path to the .leann index file
+            enable_warmup: If True, pre-load the embedding model on initialization
+                for faster first search. This adds initialization time
+                but reduces latency on the first search query.
+            **backend_kwargs: Additional arguments passed to the backend searcher
+        """
         # Fix path resolution for Colab and other environments
         if not Path(index_path).is_absolute():
             index_path = str(Path(index_path).resolve())
@@ -903,6 +912,12 @@ def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
             index_path, **final_kwargs
         )

+        # Auto-warmup if requested - this pre-loads the embedding model
+        # to avoid cold-start latency on the first search
+        self._warmup_enabled = enable_warmup
+        if enable_warmup:
+            self.warmup()
+
     def search(
         self,
         query: str,
@@ -1165,6 +1180,60 @@ def _python_regex_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
         matches.sort(key=lambda x: x.score, reverse=True)
         return matches[:top_k]

+    def warmup(self, port: int = 5557) -> float:
+        """Pre-warm the embedding server and model for faster subsequent searches.
+
+        This method starts the embedding server (if not already running) and
+        ensures the embedding model is loaded and cached. Call this before
+        your first search to avoid cold-start latency.
+
+        Args:
+            port: ZMQ port for the embedding server (default: 5557)
+
+        Returns:
+            Time taken for warmup in seconds
+
+        Example:
+            >>> searcher = LeannSearcher("path/to/index.leann")
+            >>> warmup_time = searcher.warmup()
+            >>> print(f"Warmup completed in {warmup_time:.2f}s")
+            >>> # Subsequent searches will be faster
+            >>> results = searcher.search("my query")
+        """
+        import time
+
+        start_time = time.time()
+        logger.info("Starting LeannSearcher warmup...")
+
+        try:
+            # Start the embedding server with warmup enabled
+            # This triggers model loading in the server process
+            zmq_port = self.backend_impl._ensure_server_running(
+                self.meta_path_str,
+                port=port,
+                enable_warmup=True,
+            )
+
+            # Optionally, do a dummy query to ensure everything is fully warmed up
+            # This tests the full path including ZMQ communication
+            try:
+                _ = self.backend_impl.compute_query_embedding(
+                    "warmup test query",
+                    use_server_if_available=True,
+                    zmq_port=zmq_port,
+                )
+            except Exception as e:
+                logger.warning(f"Warmup query failed (non-fatal): {e}")
+
+            warmup_time = time.time() - start_time
+            logger.info(f"LeannSearcher warmup completed in {warmup_time:.2f}s")
+            return warmup_time
+
+        except Exception as e:
+            warmup_time = time.time() - start_time
+            logger.warning(f"Warmup partially failed after {warmup_time:.2f}s: {e}")
+            return warmup_time
+
     def cleanup(self):
         """Explicitly cleanup embedding server resources.
         This method should be called after you're done using the searcher,
diff --git a/packages/leann-core/src/leann/embedding_server_manager.py b/packages/leann-core/src/leann/embedding_server_manager.py
index ca61d053..ea064ab8 100644
--- a/packages/leann-core/src/leann/embedding_server_manager.py
+++ b/packages/leann-core/src/leann/embedding_server_manager.py
@@ -337,6 +337,9 @@ def _build_server_command(
             command.extend(["--embedding-mode", embedding_mode])
         if kwargs.get("distance_metric"):
             command.extend(["--distance-metric", kwargs["distance_metric"]])
+        # Control warmup behavior - default is enabled, use --no-warmup to disable
+        if not kwargs.get("enable_warmup", True):
+            command.append("--no-warmup")

         return command

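Taken together, the api.py changes give callers two ways to pay the warmup cost up front. A usage sketch ("my_index.leann" is a placeholder for an existing index):

    from leann.api import LeannSearcher

    # Option 1 (one-liner): LeannSearcher("my_index.leann", enable_warmup=True)

    # Option 2: construct lazily, then warm up explicitly before the first query.
    searcher = LeannSearcher("my_index.leann")
    warmup_time = searcher.warmup()  # default ZMQ port 5557
    print(f"Warmup took {warmup_time:.2f}s")

    results = searcher.search("my query", top_k=3)  # cold-start cost already paid
    searcher.cleanup()
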
diff --git a/tests/test_warmup.py b/tests/test_warmup.py
new file mode 100644
index 00000000..a520be59
--- /dev/null
+++ b/tests/test_warmup.py
@@ -0,0 +1,222 @@
+"""
+Tests for warmup functionality to reduce search latency.
+
+These tests verify that:
+1. The warmup() method can be called on LeannSearcher
+2. enable_warmup=True causes auto-warmup during initialization
+3. Warmup reduces latency on subsequent searches
+"""
+
+import os
+import time
+from unittest.mock import patch
+
+import pytest
+
+
+@pytest.fixture
+def sample_index(tmp_path):
+    """Create a small sample index for testing."""
+    from leann.api import LeannBuilder
+
+    index_path = str(tmp_path / "test_warmup.hnsw")
+    texts = [f"This is test document {i} about topic {i % 3}" for i in range(20)]
+
+    builder = LeannBuilder(
+        backend_name="hnsw",
+        embedding_model="facebook/contriever",
+        embedding_mode="sentence-transformers",
+        M=16,
+        efConstruction=100,
+    )
+
+    for text in texts:
+        builder.add_text(text)
+
+    builder.build_index(index_path)
+    return index_path
+
+
+class TestWarmupMethod:
+    """Test the warmup() method on LeannSearcher."""
+
+    @pytest.mark.skipif(
+        os.environ.get("CI") == "true",
+        reason="Skip model tests in CI to avoid memory issues",
+    )
+    def test_warmup_method_exists(self, sample_index):
+        """Test that warmup method exists and is callable."""
+        from leann.api import LeannSearcher
+
+        searcher = LeannSearcher(sample_index, enable_warmup=False)
+        try:
+            assert hasattr(searcher, "warmup")
+            assert callable(searcher.warmup)
+        finally:
+            searcher.cleanup()
+
+    @pytest.mark.skipif(
+        os.environ.get("CI") == "true",
+        reason="Skip model tests in CI to avoid memory issues",
+    )
+    def test_warmup_returns_time(self, sample_index):
+        """Test that warmup() returns the time taken."""
+        from leann.api import LeannSearcher
+
+        searcher = LeannSearcher(sample_index, enable_warmup=False)
+        try:
+            warmup_time = searcher.warmup()
+            assert isinstance(warmup_time, float)
+            assert warmup_time >= 0
+        finally:
+            searcher.cleanup()
+
+    @pytest.mark.skipif(
+        os.environ.get("CI") == "true",
+        reason="Skip model tests in CI to avoid memory issues",
+    )
+    def test_warmup_with_custom_port(self, sample_index):
+        """Test warmup with a custom port."""
+        from leann.api import LeannSearcher
+
+        searcher = LeannSearcher(sample_index, enable_warmup=False)
+        try:
+            # Use a different port
+            warmup_time = searcher.warmup(port=5560)
+            assert isinstance(warmup_time, float)
+        finally:
+            searcher.cleanup()
+
+
+class TestAutoWarmup:
+    """Test automatic warmup on initialization."""
+
+    @pytest.mark.skipif(
+        os.environ.get("CI") == "true",
+        reason="Skip model tests in CI to avoid memory issues",
+    )
+    def test_enable_warmup_false(self, sample_index):
+        """Test that enable_warmup=False doesn't trigger warmup."""
+        from leann.api import LeannSearcher
+
+        with patch.object(LeannSearcher, "warmup") as mock_warmup:
+            searcher = LeannSearcher(sample_index, enable_warmup=False)
+            mock_warmup.assert_not_called()
+            searcher.cleanup()
+
+    @pytest.mark.skipif(
+        os.environ.get("CI") == "true",
+        reason="Skip model tests in CI to avoid memory issues",
+    )
+    def test_enable_warmup_true(self, sample_index):
+        """Test that enable_warmup=True triggers warmup on init."""
+        from leann.api import LeannSearcher
+
+        # We can't easily mock the warmup call since it happens in __init__
+        # So we test that _warmup_enabled is set
+        searcher = LeannSearcher(sample_index, enable_warmup=True)
+        try:
+            assert searcher._warmup_enabled is True
+        finally:
+            searcher.cleanup()
+
+
+class TestWarmupLatencyImprovement:
+    """Test that warmup actually improves latency."""
+
+    @pytest.mark.skipif(
+        os.environ.get("CI") == "true",
+        reason="Skip model tests in CI to avoid memory issues",
+    )
+    def test_warmup_reduces_first_search_latency(self, sample_index):
+        """Test that warmup reduces the latency of the first search."""
+        from leann.api import LeannSearcher
+
+        # Test WITHOUT warmup - first search should be slower
+        searcher_cold = LeannSearcher(sample_index, enable_warmup=False)
+        try:
+            start_cold = time.time()
+            _ = searcher_cold.search("test document", top_k=3)
+            cold_time = time.time() - start_cold
+        finally:
+            searcher_cold.cleanup()
+
+        # Test WITH warmup - first search should be faster
+        searcher_warm = LeannSearcher(sample_index, enable_warmup=True)
+        try:
+            start_warm = time.time()
+            _ = searcher_warm.search("test document", top_k=3)
+            warm_time = time.time() - start_warm
+
+            # The warmed-up first search should be faster
+            # (or at least not significantly slower)
+            # Note: warmup time is paid upfront, so first search after warmup
+            # should be fast
+            print(f"Cold first search: {cold_time:.3f}s")
+            print(f"Warm first search: {warm_time:.3f}s")
+
+            # The warm search should complete (we don't assert strict timing
+            # as it can vary based on system load)
+            assert warm_time >= 0
+        finally:
+            searcher_warm.cleanup()
+
+
+class TestEmbeddingServerWarmup:
+    """Test warmup at the embedding server level."""
+
+    def test_hnsw_server_accepts_warmup_param(self):
+        """Test that HNSW embedding server accepts enable_warmup parameter."""
+        import inspect
+
+        from leann_backend_hnsw.hnsw_embedding_server import create_hnsw_embedding_server
+
+        sig = inspect.signature(create_hnsw_embedding_server)
+        params = sig.parameters
+        assert "enable_warmup" in params
+        assert params["enable_warmup"].default is True
+
+    def test_diskann_server_accepts_warmup_param(self):
+        """Test that DiskANN embedding server accepts enable_warmup parameter."""
+        import inspect
+
+        from leann_backend_diskann.diskann_embedding_server import (
+            create_diskann_embedding_server,
+        )
+
+        sig = inspect.signature(create_diskann_embedding_server)
+        params = sig.parameters
+        assert "enable_warmup" in params
+        assert params["enable_warmup"].default is True
+
+
+class TestServerManagerWarmup:
+    """Test warmup parameter passing in server manager."""
+
+    def test_build_command_with_warmup_enabled(self):
+        """Test that warmup enabled doesn't add --no-warmup flag."""
+        from leann.embedding_server_manager import EmbeddingServerManager
+
+        manager = EmbeddingServerManager("leann_backend_hnsw.hnsw_embedding_server")
+        cmd = manager._build_server_command(
+            port=5557,
+            model_name="test-model",
+            embedding_mode="sentence-transformers",
+            enable_warmup=True,
+        )
+
+        assert "--no-warmup" not in cmd
+
+    def test_build_command_with_warmup_disabled(self):
+        """Test that warmup disabled adds --no-warmup flag."""
+        from leann.embedding_server_manager import EmbeddingServerManager
+
+        manager = EmbeddingServerManager("leann_backend_hnsw.hnsw_embedding_server")
+        cmd = manager._build_server_command(
+            port=5557,
+            model_name="test-model",
+            embedding_mode="sentence-transformers",
+            enable_warmup=False,
+        )
+
+        assert "--no-warmup" in cmd
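One possible tightening for test_enable_warmup_true: despite the in-test comment, the patch.object pattern already used for the False case appears to work for the True case as well, since the class attribute is replaced before __init__ runs. A hedged sketch, not part of the diff:

    # Hypothetical extra test, reusing the file's own patch.object pattern
    # to assert the auto-warmup call directly.
    from unittest.mock import patch

    from leann.api import LeannSearcher

    def test_enable_warmup_true_calls_warmup(sample_index):
        with patch.object(LeannSearcher, "warmup") as mock_warmup:
            searcher = LeannSearcher(sample_index, enable_warmup=True)
            mock_warmup.assert_called_once()
            searcher.cleanup()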