diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 7e9a6af23..4d37bf959 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -932,7 +932,8 @@ def generate( sample_idx += 1 if stopping_criteria is not None and stopping_criteria( - self._input_ids[: sample_idx], self._scores[sample_idx - self.n_tokens, :] + self._input_ids[:sample_idx], + self._scores[sample_idx - self.n_tokens, :], ): return tokens_or_none = yield token @@ -958,7 +959,11 @@ def generate( ) def create_embedding( - self, input: Union[str, List[str]], model: Optional[str] = None + self, + input: Union[str, List[str]], + model: Optional[str] = None, + normalize: bool = False, + truncate: bool = True, ) -> CreateEmbeddingResponse: """Embed a string. @@ -975,7 +980,9 @@ def create_embedding( # get numeric embeddings embeds: Union[List[List[float]], List[List[List[float]]]] total_tokens: int - embeds, total_tokens = self.embed(input, return_count=True) # type: ignore + embeds, total_tokens = self.embed( + input=input, normalize=normalize, truncate=truncate, return_count=True + ) # convert to CreateEmbeddingResponse data: List[Embedding] = [ @@ -1313,7 +1320,7 @@ def logit_bias_processor( if seed is not None: self.set_seed(seed) else: - self.set_seed(random.Random(self._seed).randint(0, 2 ** 32)) + self.set_seed(random.Random(self._seed).randint(0, 2**32)) finish_reason = "length" multibyte_fix = 0 @@ -2054,7 +2061,10 @@ def create_chat_completion_openai_v1( stream = kwargs.get("stream", False) # type: ignore assert isinstance(stream, bool) if stream: - return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs)) # type: ignore + return ( + ChatCompletionChunk(**chunk) + for chunk in self.create_chat_completion(*args, **kwargs) + ) # type: ignore else: return ChatCompletion(**self.create_chat_completion(*args, **kwargs)) # type: ignore except ImportError: @@ -2314,7 +2324,11 @@ def from_pretrained( if additional_files: for additonal_file_name in additional_files: # find the additional shard file: - matching_additional_files = [file for file in file_list if fnmatch.fnmatch(file, additonal_file_name)] + matching_additional_files = [ + file + for file in file_list + if fnmatch.fnmatch(file, additonal_file_name) + ] if len(matching_additional_files) == 0: raise ValueError(