diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index d15a88b00..82e27530d 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -994,12 +994,32 @@ def create_embedding(
             },
         }
 
+    def rank(
+        self,
+        query: str,
+        documents: List[str]
+    ) -> List[float]:
+        """Rank a query against a list of documents.
+
+        Args:
+            query: The utf-8 encoded query string.
+            documents: The utf-8 encoded list of documents.
+
+        Returns:
+            A list of rank scores.
+        """
+        query_doc_texts = [f"{query}{doc}" for doc in documents]
+        embeds = self.embed(query_doc_texts, special_tokenize=True)
+        rank_scores = [embed[0] for embed in embeds]
+        return rank_scores
+
     def embed(
         self,
         input: Union[str, List[str]],
         normalize: bool = False,
         truncate: bool = True,
         return_count: bool = False,
+        special_tokenize: bool = False,
     ):
         """Embed a string.
@@ -1071,7 +1091,7 @@ def decode_batch(seq_sizes: List[int]):
 
         # accumulate batches and encode
         for text in inputs:
-            tokens = self.tokenize(text.encode("utf-8"))
+            tokens = self.tokenize(text.encode("utf-8"), special=special_tokenize)
             if truncate:
                 tokens = tokens[:n_batch]