From ef4fef539277d67e2d7398d11d0962ed416527de Mon Sep 17 00:00:00 2001
From: TY <42710806+CTY-git@users.noreply.github.com>
Date: Fri, 1 Nov 2024 19:47:55 +0800
Subject: [PATCH] Fix hf regression due to prompt truncation (#994)

---
 patchwork/common/client/llm/aio.py     | 4 ++--
 patchwork/common/client/llm/openai_.py | 4 ++++
 patchwork/steps/JoinList/JoinList.py   | 7 +++++--
 pyproject.toml                         | 2 +-
 4 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/patchwork/common/client/llm/aio.py b/patchwork/common/client/llm/aio.py
index 02faef0e..cdb9714c 100644
--- a/patchwork/common/client/llm/aio.py
+++ b/patchwork/common/client/llm/aio.py
@@ -29,11 +29,11 @@ def get_models(self) -> set[str]:
     def is_model_supported(self, model: str) -> bool:
         return any(client.is_model_supported(model) for client in self.__clients)
 
-    def is_prompt_supported(self, messages: Iterable[ChatCompletionMessageParam], model: str) -> bool:
+    def is_prompt_supported(self, messages: Iterable[ChatCompletionMessageParam], model: str) -> int:
         for client in self.__clients:
             if client.is_model_supported(model):
                 return client.is_prompt_supported(messages, model)
-        return False
+        return -1
 
     def truncate_messages(
         self, messages: Iterable[ChatCompletionMessageParam], model: str
diff --git a/patchwork/common/client/llm/openai_.py b/patchwork/common/client/llm/openai_.py
index 71a7009a..1ce603f0 100644
--- a/patchwork/common/client/llm/openai_.py
+++ b/patchwork/common/client/llm/openai_.py
@@ -62,6 +62,10 @@ def __get_model_limits(self, model: str) -> int:
         return self.__MODEL_LIMITS.get(model, 128_000)
 
     def is_prompt_supported(self, messages: Iterable[ChatCompletionMessageParam], model: str) -> int:
+        # might not implement model endpoint
+        if self.__is_not_openai_url():
+            return 1
+
         model_limit = self.__get_model_limits(model)
         token_count = 0
         encoding = tiktoken.encoding_for_model(model)
diff --git a/patchwork/steps/JoinList/JoinList.py b/patchwork/steps/JoinList/JoinList.py
index b1b4bb53..893ebb4c 100644
--- a/patchwork/steps/JoinList/JoinList.py
+++ b/patchwork/steps/JoinList/JoinList.py
@@ -24,11 +24,14 @@ def run(self):
             if isinstance(item, str):
                 items.append(item)
             elif isinstance(item, dict):
+                is_added = False
                 for possible_key in self.possible_keys:
                     if possible_key in item.keys():
                         items.append(item.get(possible_key))
-                    else:
-                        items.append(json.dumps(item))
+                        is_added = True
+                        break
+                if not is_added:
+                    items.append(json.dumps(item))
             else:
                 items.append(str(item))
 
diff --git a/pyproject.toml b/pyproject.toml
index 559d0802..16473be0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "patchwork-cli"
-version = "0.0.75"
+version = "0.0.76"
 description = ""
 authors = ["patched.codes"]
 license = "AGPL"
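
Note on the client changes: the aio client's is_prompt_supported now returns an
int rather than a bool (-1 when no registered client supports the model),
matching the OpenAI client's int signature, and the OpenAI client returns 1
early when the configured base URL is not OpenAI's, since (per the in-diff
comment) such endpoints might not implement the model endpoint. This presumably
restores support for the Hugging Face endpoints referred to in the subject.

Note on the JoinList change: a dict now contributes exactly one entry, the value
of the first matching key, or its JSON dump if none of the possible keys match,
instead of one json.dumps per non-matching key as before. A minimal sketch of
the new behaviour in isolation, assuming possible_keys and item as used in the
step (illustrative only, not part of the patch):

    import json

    def join_item(item, possible_keys):
        # Strings pass through unchanged.
        if isinstance(item, str):
            return item
        # Dicts: return the first matching key's value, else dump the dict once.
        if isinstance(item, dict):
            for possible_key in possible_keys:
                if possible_key in item:
                    return item[possible_key]
            return json.dumps(item)
        # Anything else is stringified.
        return str(item)

    # Example: with possible_keys=["body", "text"], a dict lacking both keys is
    # dumped once rather than once per key as under the old loop.
    print(join_item({"title": "x"}, ["body", "text"]))  # {"title": "x"}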