60 changes: 27 additions & 33 deletions python/fast_plaid/search/fast_plaid.py
@@ -115,7 +115,7 @@ def compute_kmeans(  # noqa: PLR0913

 def search_on_device(  # noqa: PLR0913
     device: str,
-    queries_embeddings: torch.Tensor,
+    queries_embeddings: list[torch.Tensor],
     batch_size: int,
     n_full_scores: int,
     top_k: int,
@@ -154,6 +154,18 @@ def search_on_device(  # noqa: PLR0913
     ]


+def cleanup_embeddings(
+    embeddings: list[torch.Tensor] | torch.Tensor,
+) -> list[torch.Tensor]:
+    """Convert embeddings to a list and remove extra dimensions."""
+    if isinstance(embeddings, torch.Tensor):
+        embeddings = [embeddings[i] for i in range(embeddings.shape[0])]
+    return [
+        embedding.squeeze(0) if embedding.dim() == 3 else embedding
+        for embedding in embeddings
+    ]
+
+
 class FastPlaid:
     """A class for creating and searching a FastPlaid index.

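The new cleanup_embeddings helper gives create() and search() a single normalization path. A minimal sketch of its behavior, assuming the helper is importable from fast_plaid.search.fast_plaid and using hypothetical shapes (a stacked 3D batch, and a ragged list with a stray leading dimension):

import torch

from fast_plaid.search.fast_plaid import cleanup_embeddings

batch = torch.randn(4, 32, 128)  # one stacked tensor: 4 queries, 32 tokens, dim 128
ragged = [torch.randn(1, 20, 128), torch.randn(30, 128)]  # mixed [1, t, d] and [t, d]

# A 3D tensor is unstacked into a list of 2D tensors.
assert [tuple(e.shape) for e in cleanup_embeddings(batch)] == [(32, 128)] * 4

# A list keeps its ragged lengths; a leading singleton dimension is squeezed away.
assert [tuple(e.shape) for e in cleanup_embeddings(ragged)] == [(20, 128), (30, 128)]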
Expand Down Expand Up @@ -288,15 +300,7 @@ def create( # noqa: PLR0913
             Optional list of dictionaries containing metadata for each document.

         """
-        if isinstance(documents_embeddings, torch.Tensor):
-            documents_embeddings = [
-                documents_embeddings[i] for i in range(documents_embeddings.shape[0])
-            ]
-
-        documents_embeddings = [
-            embedding.squeeze(0) if embedding.dim() == 3 else embedding
-            for embedding in documents_embeddings
-        ]
+        documents_embeddings = cleanup_embeddings(documents_embeddings)
         num_docs = len(documents_embeddings)

         self._prepare_index_directory(index_path=self.index)
@@ -473,17 +477,8 @@ def search(  # noqa: PLR0913, C901, PLR0912, PLR0915
                 corresponding inner list.

         """
-        if isinstance(queries_embeddings, list):
-            queries_embeddings = torch.nn.utils.rnn.pad_sequence(
-                sequences=[
-                    embedding[0] if embedding.dim() == 3 else embedding
-                    for embedding in queries_embeddings
-                ],
-                batch_first=True,
-                padding_value=0.0,
-            )
-
-        num_queries = queries_embeddings.shape[0]
+        queries_embeddings = cleanup_embeddings(queries_embeddings)
+        num_queries = len(queries_embeddings)

         if subset is not None:
             if isinstance(subset, int):
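For context on what the removed branch did: pad_sequence forced ragged queries into one dense batch, inserting all-zero rows into shorter queries, while the new list form preserves true lengths. A small sketch with hypothetical shapes:

import torch

a = torch.randn(5, 128)  # 5-token query
b = torch.randn(9, 128)  # 9-token query

# Removed path: one dense batch; the short query gains 4 all-zero padding rows.
padded = torch.nn.utils.rnn.pad_sequence([a, b], batch_first=True, padding_value=0.0)
assert padded.shape == (2, 9, 128)
assert torch.all(padded[0, 5:] == 0)

# New path: queries stay ragged, so no padding tokens reach the scorer.
assert [q.shape[0] for q in [a, b]] == [5, 9]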
Expand Down Expand Up @@ -529,16 +524,15 @@ def search( # noqa: PLR0913, C901, PLR0912, PLR0915
         num_cpus = len(self.devices)

         # Use torch.chunk to split the tensor into num_cpus
-        queries_embeddings_splits = torch.chunk(
-            input=queries_embeddings,
-            chunks=num_cpus,
-            dim=0,
-        )
+        queries_embeddings_splits = [
+            queries_embeddings[i : i + num_cpus]
+            for i in range(0, num_queries, num_cpus)
+        ]

         # Filter out empty chunks that torch.chunk might create
         # if num_queries < num_cpus
         non_empty_splits = [
-            split for split in queries_embeddings_splits if split.shape[0] > 0
+            split for split in queries_embeddings_splits if len(split) > 0
         ]
         num_splits = len(non_empty_splits)

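Note that this is not a like-for-like rewrite: torch.chunk produced at most num_cpus chunks, while the slice comprehension produces ceil(num_queries / num_cpus) groups of num_cpus queries each. A quick sketch of the difference, with hypothetical counts:

import torch

num_queries, num_cpus = 10, 4
ids = list(range(num_queries))

# Old semantics: at most num_cpus chunks, as evenly sized as possible.
old_chunks = torch.chunk(torch.arange(num_queries), chunks=num_cpus, dim=0)
assert [len(c) for c in old_chunks] == [3, 3, 3, 1]

# New semantics: fixed-size groups of num_cpus queries each.
new_chunks = [ids[i : i + num_cpus] for i in range(0, num_queries, num_cpus)]
assert [len(c) for c in new_chunks] == [4, 4, 2]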
@@ -548,7 +542,7 @@ def search(  # noqa: PLR0913, C901, PLR0912, PLR0915
         if subset is not None:
             current_idx = 0
             for split in non_empty_splits:
-                size = split.shape[0]
+                size = len(split)
                 subset_splits.append(subset[current_idx : current_idx + size])  # type: ignore
                 current_idx += size

@@ -600,16 +594,16 @@ def search(  # noqa: PLR0913, C901, PLR0912, PLR0915
                 subset=subset,  # type: ignore
             )

-        queries_embeddings_splits = torch.split(
-            tensor=queries_embeddings,
-            split_size_or_sections=len(self.devices),
-        )
+        queries_embeddings_splits = [
+            queries_embeddings[i : i + len(self.devices)]
+            for i in range(0, num_queries, len(self.devices))
+        ]

         num_splits = len(queries_embeddings_splits)
         if subset is not None:
             current_idx = 0
             for split in queries_embeddings_splits:
-                size = split.shape[0]
+                size = len(split)
                 subset_splits.append(subset[current_idx : current_idx + size])  # type: ignore
                 current_idx += size
         else:
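Unlike the torch.chunk case above, torch.split with a split size already produced fixed-size groups, so this rewrite preserves the grouping while operating on a list. A quick equivalence check with hypothetical counts:

import torch

queries = list(range(5))
n_devices = 2

tensor_groups = torch.split(torch.arange(5), split_size_or_sections=n_devices)
list_groups = [queries[i : i + n_devices] for i in range(0, len(queries), n_devices)]

assert [len(g) for g in tensor_groups] == [len(g) for g in list_groups] == [2, 2, 1]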
9 changes: 7 additions & 2 deletions rust/lib.rs
@@ -376,7 +376,7 @@ fn load_and_search(
     index: String,
     torch_path: String,
     device: String,
-    queries_embeddings: PyTensor,
+    queries_embeddings: Vec<PyTensor>,
     search_parameters: &SearchParameters,
     show_progress: bool,
     preload_index: bool,
@@ -397,9 +397,14 @@ fn load_and_search(
         Ok(Arc::new(loaded_index))
     }?;

+    let queries_embeddings: Vec<_> = queries_embeddings
+        .into_iter()
+        .map(|tensor| tensor.to_kind(Kind::Half))
+        .collect();
+
     // Perform the search
     let results = search_many(
-        &queries_embeddings.to_kind(Kind::Half),
+        &queries_embeddings,
         &index,
         search_parameters,
         device_tch,
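The float16 cast that previously applied once to a batched tensor now runs per query tensor before search_many is called. A hypothetical Python-side analogue of that per-tensor cast:

import torch

queries = [torch.randn(5, 128), torch.randn(9, 128)]

# Rust side, sketched in Python: cast each query to half precision individually.
halved = [q.to(torch.float16) for q in queries]
assert all(q.dtype == torch.float16 for q in halved)
assert [q.shape for q in halved] == [q.shape for q in queries]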
21 changes: 10 additions & 11 deletions rust/search/search.rs
@@ -1,8 +1,8 @@
-use anyhow::{anyhow, bail, Result};
+use anyhow::{anyhow, Result};
 use indicatif::{ProgressBar, ProgressIterator};
 use pyo3::prelude::*;
 use serde::Serialize;
-use tch::{Device, IndexOp, Kind, Tensor};
+use tch::{Device, Kind, Tensor};

 use crate::search::load::LoadedIndex;
 use crate::search::padding::direct_pad_sequences;
@@ -165,22 +165,21 @@ impl SearchParameters {
 /// A `Result` with a `Vec<QueryResult>`. Individual search failures result in an empty
 /// `QueryResult` for that specific query, ensuring the operation doesn't halt.
 pub fn search_many(
-    queries: &Tensor,
+    queries: &Vec<Tensor>,
     index: &LoadedIndex,
     params: &SearchParameters,
     device: Device,
     show_progress: bool,
     subset: Option<Vec<Vec<i64>>>,
 ) -> Result<Vec<QueryResult>> {
-    let [num_queries, _, query_dim] = queries.size()[..] else {
-        bail!(
-            "Expected a 3D tensor for queries, but got shape {:?}",
-            queries.size()
-        );
-    };
+    let num_queries = queries.len();
+    if num_queries == 0 {
+        return Ok(Vec::new());
+    }
+    let query_dim = queries[0].size()[queries[0].dim() - 1];

-    let search_closure = |query_index| {
-        let query_embedding = queries.i(query_index).to(device);
+    let search_closure = |query_index: usize| {
+        let query_embedding = &queries[query_index].to(device);

         // Handle the per-query subset list
         let query_subset = subset.as_ref().and_then(|s| s.get(query_index as usize));
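A hypothetical Python mirror of the new search_many control flow: an early return for empty input, query_dim read from the last dimension of the first query, and per-query indexing instead of slicing one 3D tensor:

import torch

def shape_contract(queries: list[torch.Tensor]) -> list[tuple[int, int]]:
    if not queries:
        return []  # mirrors `return Ok(Vec::new())`
    query_dim = queries[0].shape[-1]  # mirrors queries[0].size()[dim - 1]
    return [(q.shape[0], query_dim) for q in queries]

assert shape_contract([]) == []
assert shape_contract([torch.randn(5, 128), torch.randn(9, 128)]) == [(5, 128), (9, 128)]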