@@ -48,13 +48,23 @@ def create_diskann_embedding_server(
model_name: str = "sentence-transformers/all-mpnet-base-v2",
embedding_mode: str = "sentence-transformers",
distance_metric: str = "l2",
enable_warmup: bool = True,
):
"""
Create and start a ZMQ-based embedding server for DiskANN backend.
Uses a ROUTER socket and protobuf communication, as required by the DiskANN C++ implementation.

Args:
passages_file: Path to the metadata file (.meta.json)
zmq_port: Port for ZMQ server
model_name: Name of the embedding model to use
embedding_mode: Embedding backend mode
distance_metric: Distance metric (l2, mips, cosine)
enable_warmup: If True, pre-load model and run warmup embedding on startup
"""
logger.info(f"Starting DiskANN server on port {zmq_port} with model {model_name}")
logger.info(f"Using embedding mode: {embedding_mode}")
logger.info(f"Warmup enabled: {enable_warmup}")

# Add leann-core to path for unified embedding computation
current_dir = Path(__file__).parent
@@ -72,6 +82,24 @@ def create_diskann_embedding_server(
finally:
sys.path.pop(0)

# Warmup: Pre-load the embedding model by computing a dummy embedding
# This ensures the model is cached and ready for fast subsequent queries
if enable_warmup:
warmup_start = time.time()
logger.info("Starting model warmup...")
try:
# Compute a dummy embedding to trigger model loading and caching
_ = compute_embeddings(
["warmup query for model preloading"],
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
warmup_time = time.time() - warmup_start
logger.info(f"Model warmup completed in {warmup_time:.2f}s")
except Exception as e:
logger.warning(f"Model warmup failed (non-fatal): {e}")

# Check port availability
import socket

@@ -479,14 +507,29 @@ def signal_handler(sig, frame):
choices=["l2", "mips", "cosine"],
help="Distance metric for similarity computation",
)
parser.add_argument(
"--enable-warmup",
action="store_true",
default=True,
help="Pre-load embedding model on startup for faster first query (default: True)",
)
parser.add_argument(
"--no-warmup",
action="store_true",
help="Disable warmup (lazy model loading)",
)

args = parser.parse_args()

# Determine warmup setting (--no-warmup takes precedence)
enable_warmup = not args.no_warmup

# Create and start the DiskANN embedding server
create_diskann_embedding_server(
passages_file=args.passages_file,
zmq_port=args.zmq_port,
model_name=args.model_name,
embedding_mode=args.embedding_mode,
distance_metric=args.distance_metric,
enable_warmup=enable_warmup,
)
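
Worth noting: because `--enable-warmup` uses `action="store_true"` with `default=True`, passing it never changes anything; `--no-warmup` is the only switch that actually affects behavior. A minimal sketch of the resolution, assuming only the two arguments shown above:

```python
# Sketch of the warmup flag resolution above (two arguments only).
# --enable-warmup is effectively a no-op since its default is already
# True; --no-warmup is the sole switch that flips the result.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable-warmup", action="store_true", default=True)
parser.add_argument("--no-warmup", action="store_true")

for argv in ([], ["--enable-warmup"], ["--no-warmup"]):
    args = parser.parse_args(argv)
    print(argv, "-> enable_warmup =", not args.no_warmup)
# []                  -> enable_warmup = True
# ['--enable-warmup'] -> enable_warmup = True
# ['--no-warmup']     -> enable_warmup = False
```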
@@ -61,13 +61,23 @@ def create_hnsw_embedding_server(
model_name: str = "sentence-transformers/all-mpnet-base-v2",
distance_metric: str = "mips",
embedding_mode: str = "sentence-transformers",
enable_warmup: bool = True,
):
"""
Create and start a ZMQ-based embedding server for HNSW backend.
Simplified version using the unified embedding computation module.

Args:
passages_file: Path to the metadata file (.meta.json)
zmq_port: Port for ZMQ server
model_name: Name of the embedding model to use
distance_metric: Distance metric (mips, l2, cosine)
embedding_mode: Embedding backend mode
enable_warmup: If True, pre-load model and run warmup embedding on startup
"""
logger.info(f"Starting HNSW server on port {zmq_port} with model {model_name}")
logger.info(f"Using embedding mode: {embedding_mode}")
logger.info(f"Warmup enabled: {enable_warmup}")

# Add leann-core to path for unified embedding computation
current_dir = Path(__file__).parent
@@ -85,6 +95,24 @@ def create_hnsw_embedding_server(
finally:
sys.path.pop(0)

# Warmup: Pre-load the embedding model by computing a dummy embedding
# This ensures the model is cached and ready for fast subsequent queries
if enable_warmup:
warmup_start = time.time()
logger.info("Starting model warmup...")
try:
# Compute a dummy embedding to trigger model loading and caching
_ = compute_embeddings(
["warmup query for model preloading"],
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
warmup_time = time.time() - warmup_start
logger.info(f"Model warmup completed in {warmup_time:.2f}s")
except Exception as e:
logger.warning(f"Model warmup failed (non-fatal): {e}")

# Check port availability
import socket

@@ -481,14 +509,29 @@ def signal_handler(sig, frame):
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode",
)
parser.add_argument(
"--enable-warmup",
action="store_true",
default=True,
help="Pre-load embedding model on startup for faster first query (default: True)",
)
parser.add_argument(
"--no-warmup",
action="store_true",
help="Disable warmup (lazy model loading)",
)

args = parser.parse_args()

# Determine warmup setting (--no-warmup takes precedence)
enable_warmup = not args.no_warmup

# Create and start the HNSW embedding server
create_hnsw_embedding_server(
passages_file=args.passages_file,
zmq_port=args.zmq_port,
model_name=args.model_name,
distance_metric=args.distance_metric,
embedding_mode=args.embedding_mode,
enable_warmup=enable_warmup,
)
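
The same server can also be started programmatically instead of via the CLI; a minimal sketch of a direct call, where the passages file and port are placeholder assumptions and the remaining arguments mirror the signature's defaults:

```python
# Hypothetical direct invocation of the entry point shown above.
# passages_file and zmq_port are placeholder values; the rest mirror
# the defaults in the function signature. The call blocks while the
# server runs.
create_hnsw_embedding_server(
    passages_file="my_index.meta.json",  # assumed metadata path
    zmq_port=5557,                       # assumed port
    model_name="sentence-transformers/all-mpnet-base-v2",
    distance_metric="mips",
    embedding_mode="sentence-transformers",
    enable_warmup=False,  # lazy loading: the first query pays the model-load cost
)
```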
69 changes: 69 additions & 0 deletions packages/leann-core/src/leann/api.py
@@ -865,6 +865,15 @@ def update_index(self, index_path: str):

class LeannSearcher:
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
"""Initialize a LeannSearcher for searching an existing index.

Args:
index_path: Path to the .leann index file
enable_warmup: If True, pre-load the embedding model on initialization
for faster first search. This adds initialization time
but reduces latency on the first search query.
**backend_kwargs: Additional arguments passed to the backend searcher
"""
# Fix path resolution for Colab and other environments
if not Path(index_path).is_absolute():
index_path = str(Path(index_path).resolve())
@@ -903,6 +912,12 @@ def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
index_path, **final_kwargs
)

# Auto-warmup if requested - this pre-loads the embedding model
# to avoid cold-start latency on the first search
self._warmup_enabled = enable_warmup
if enable_warmup:
self.warmup()

def search(
self,
query: str,
@@ -1165,6 +1180,60 @@ def _python_regex_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
matches.sort(key=lambda x: x.score, reverse=True)
return matches[:top_k]

def warmup(self, port: int = 5557) -> float:
"""Pre-warm the embedding server and model for faster subsequent searches.

This method starts the embedding server (if not already running) and
ensures the embedding model is loaded and cached. Call this before
your first search to avoid cold-start latency.

Args:
port: ZMQ port for the embedding server (default: 5557)

Returns:
Time taken for warmup in seconds

Example:
>>> searcher = LeannSearcher("path/to/index.leann")
>>> warmup_time = searcher.warmup()
>>> print(f"Warmup completed in {warmup_time:.2f}s")
>>> # Subsequent searches will be faster
>>> results = searcher.search("my query")
"""
import time

start_time = time.time()
logger.info("Starting LeannSearcher warmup...")

try:
# Start the embedding server with warmup enabled
# This triggers model loading in the server process
zmq_port = self.backend_impl._ensure_server_running(
self.meta_path_str,
port=port,
enable_warmup=True,
)

# Optionally, do a dummy query to ensure everything is fully warmed up
# This tests the full path including ZMQ communication
try:
_ = self.backend_impl.compute_query_embedding(
"warmup test query",
use_server_if_available=True,
zmq_port=zmq_port,
)
except Exception as e:
logger.warning(f"Warmup query failed (non-fatal): {e}")

warmup_time = time.time() - start_time
logger.info(f"LeannSearcher warmup completed in {warmup_time:.2f}s")
return warmup_time

except Exception as e:
warmup_time = time.time() - start_time
logger.warning(f"Warmup partially failed after {warmup_time:.2f}s: {e}")
return warmup_time

def cleanup(self):
"""Explicitly cleanup embedding server resources.
This method should be called after you're done using the searcher,
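
Taken together, the api.py changes make the warmup path a one-liner for callers; a minimal usage sketch, assuming an existing index at `docs.leann` (the `top_k` keyword is inferred from the surrounding search code):

```python
from leann.api import LeannSearcher

# enable_warmup=True front-loads model loading at construction time,
# trading a slower init for a fast first search.
searcher = LeannSearcher("docs.leann", enable_warmup=True)
results = searcher.search("what does warmup change?", top_k=5)  # top_k assumed
searcher.cleanup()  # release the embedding server when done
```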
3 changes: 3 additions & 0 deletions packages/leann-core/src/leann/embedding_server_manager.py
@@ -337,6 +337,9 @@ def _build_server_command(
command.extend(["--embedding-mode", embedding_mode])
if kwargs.get("distance_metric"):
command.extend(["--distance-metric", kwargs["distance_metric"]])
# Control warmup behavior - default is enabled, use --no-warmup to disable
if not kwargs.get("enable_warmup", True):
command.append("--no-warmup")

return command

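
On the manager side the wiring reduces to one conditional; a sketch of the three possible `enable_warmup` states, with a placeholder base command standing in for the real server launch:

```python
# Sketch of the --no-warmup wiring in _build_server_command: only an
# explicit enable_warmup=False appends the flag; a missing or True
# value keeps the server's default (warmup enabled).
for kwargs in ({}, {"enable_warmup": True}, {"enable_warmup": False}):
    command = ["python", "-m", "server_module"]  # placeholder base command
    if not kwargs.get("enable_warmup", True):
        command.append("--no-warmup")
    print(kwargs, "->", command)
```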