Fix hf regression due to prompt truncation (#994)
CTY-git authored Nov 1, 2024
1 parent bbae0ec commit ef4fef5
Showing 4 changed files with 12 additions and 5 deletions.
4 changes: 2 additions & 2 deletions patchwork/common/client/llm/aio.py
@@ -29,11 +29,11 @@ def get_models(self) -> set[str]:
     def is_model_supported(self, model: str) -> bool:
         return any(client.is_model_supported(model) for client in self.__clients)
 
-    def is_prompt_supported(self, messages: Iterable[ChatCompletionMessageParam], model: str) -> bool:
+    def is_prompt_supported(self, messages: Iterable[ChatCompletionMessageParam], model: str) -> int:
         for client in self.__clients:
             if client.is_model_supported(model):
                 return client.is_prompt_supported(messages, model)
-        return False
+        return -1
 
     def truncate_messages(
         self, messages: Iterable[ChatCompletionMessageParam], model: str
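With this change the aggregate client forwards the provider client's integer verdict instead of collapsing it to a bool, and uses -1 rather than False as the "unsupported" sentinel. Below is a caller-side sketch of the new contract; aio_client stands in for a configured aggregate client and the message list is illustrative, neither appears in this commit.

    # Illustrative caller-side handling of the new int contract; `aio_client`
    # is a stand-in for a configured aggregate client, not defined here.
    messages = [{"role": "user", "content": "Summarize this diff."}]
    remaining = aio_client.is_prompt_supported(messages, model="gpt-4o")
    if remaining < 0:
        # -1: no registered client supports the model, or the prompt does
        # not fit its context window; truncate before sending.
        messages = aio_client.truncate_messages(messages, model="gpt-4o")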
4 changes: 4 additions & 0 deletions patchwork/common/client/llm/openai_.py
@@ -62,6 +62,10 @@ def __get_model_limits(self, model: str) -> int:
         return self.__MODEL_LIMITS.get(model, 128_000)
 
     def is_prompt_supported(self, messages: Iterable[ChatCompletionMessageParam], model: str) -> int:
+        # might not implement model endpoint
+        if self.__is_not_openai_url():
+            return 1
+
         model_limit = self.__get_model_limits(model)
         token_count = 0
         encoding = tiktoken.encoding_for_model(model)
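This hunk is the core of the regression fix: when the client points at a non-OpenAI base URL (for example a Hugging Face inference endpoint), the tiktoken model limits do not apply, so the check short-circuits with a positive value instead of reporting the prompt as unsupported and triggering truncation. A standalone paraphrase of the check follows; the free-function form, the is_not_openai_url parameter, and the token-counting tail are reconstructions from the context lines above, not the commit's exact code.

    import tiktoken

    # Hedged paraphrase; only the short-circuit at the top is verbatim
    # from the diff, the rest is a plausible reconstruction.
    def is_prompt_supported(messages, model: str, is_not_openai_url: bool,
                            model_limit: int = 128_000) -> int:
        # Non-OpenAI endpoints may not implement the model endpoint,
        # so skip the token-limit check entirely.
        if is_not_openai_url:
            return 1

        encoding = tiktoken.encoding_for_model(model)
        token_count = 0
        for message in messages:
            token_count += len(encoding.encode(str(message.get("content", ""))))
            if token_count > model_limit:
                return -1  # prompt would not fit the context window
        return model_limit - token_count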
7 changes: 5 additions & 2 deletions patchwork/steps/JoinList/JoinList.py
@@ -24,11 +24,14 @@ def run(self):
             if isinstance(item, str):
                 items.append(item)
             elif isinstance(item, dict):
+                is_added = False
                 for possible_key in self.possible_keys:
                     if possible_key in item.keys():
                         items.append(item.get(possible_key))
-                    else:
-                        items.append(json.dumps(item))
+                        is_added = True
+                        break
+                if not is_added:
+                    items.append(json.dumps(item))
             else:
                 items.append(str(item))

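The old loop appended a JSON dump of the whole dict on every key miss, so a dict could be serialized several times and could also match more than one candidate key. The is_added flag plus break makes the behavior first-match-wins, with the JSON fallback emitted exactly once. A self-contained sketch of the fixed logic; the key names and sample items are illustrative.

    import json

    # Standalone sketch of the fixed fallback logic; `possible_keys` and
    # the sample items are illustrative, not from the step's real inputs.
    possible_keys = ["body", "text", "content"]
    items = []
    for item in [{"text": "hello"}, {"unrelated": 1}]:
        if isinstance(item, dict):
            is_added = False
            for possible_key in possible_keys:
                if possible_key in item:
                    items.append(item[possible_key])
                    is_added = True
                    break  # first matching key wins
            if not is_added:
                items.append(json.dumps(item))  # JSON fallback, exactly once
        else:
            items.append(str(item))

    print(items)  # ['hello', '{"unrelated": 1}']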
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "patchwork-cli"
-version = "0.0.75"
+version = "0.0.76"
 description = ""
 authors = ["patched.codes"]
 license = "AGPL"
