Skip to content

Commit

Permalink
Count tokens per LLM call (#694)
Browse files Browse the repository at this point in the history
* count tokens

* bump version
  • Loading branch information
CTY-git authored Aug 23, 2024
1 parent b6fbfbf commit e1663fc
Show file tree
Hide file tree
Showing 10 changed files with 48 additions and 8 deletions.
33 changes: 28 additions & 5 deletions patchwork/steps/CallLLM/CallLLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

import json
import os
from dataclasses import dataclass
from itertools import islice
from pathlib import Path
from pprint import pformat
from textwrap import indent
from typing import NamedTuple

from rich.markup import escape

Expand All @@ -19,6 +21,14 @@
from patchwork.steps.CallLLM.typed import CallLLMInputs, CallLLMOutputs


@dataclass
class _InnerCallLLMResponse:
    """Internal per-call result bundling one LLM exchange with its token usage.

    Built once per prompt inside CallLLM.__call (one instance per completion)
    and unpacked in CallLLM.run into the parallel openai_responses /
    request_tokens / response_tokens output lists.
    """

    # The chat messages sent for this single completion (one prompt = a list
    # of message dicts).
    prompts: list[dict]
    # The text content of the model's reply (completion.choices[0].message.content).
    response: str
    # Token count consumed by the request side (completion.usage.prompt_tokens).
    # NOTE(review): name is singular but holds a count, not a token.
    request_token: int
    # Token count produced by the completion (completion.usage.completion_tokens).
    response_token: int


class CallLLM(Step, input_class=CallLLMInputs, output_class=CallLLMOutputs):
def __init__(self, inputs: dict):
super().__init__(inputs)
Expand Down Expand Up @@ -115,13 +125,21 @@ def run(self) -> dict:

contents = self.__call(prompts)

openai_responses = []
request_tokens = []
response_tokens = []
for content in contents:
openai_responses.append(content.response)
request_tokens.append(content.request_token)
response_tokens.append(content.response_token)

if self.save_responses_to_file:
self.__persist_to_file(contents)
self.__persist_to_file(openai_responses)

return dict(openai_responses=contents)
return dict(openai_responses=openai_responses, request_tokens=request_tokens, response_tokens=response_tokens)

def __call(self, prompts: list[dict]) -> list[str]:
contents = []
def __call(self, prompts: list[list[dict]]) -> list[_InnerCallLLMResponse]:
contents: list[_InnerCallLLMResponse] = []

# Parse model arguments
parsed_model_args = self.__parse_model_args()
Expand Down Expand Up @@ -150,7 +168,12 @@ def __call(self, prompts: list[dict]) -> list[str]:
content = completion.choices[0].message.content
logger.trace(f"Response received: \n{escape(indent(content, ' '))}")

contents.append(content)
contents.append(_InnerCallLLMResponse(
prompts=prompt,
response=content,
request_token=completion.usage.prompt_tokens,
response_token=completion.usage.completion_tokens
))

return contents

Expand Down
2 changes: 2 additions & 0 deletions patchwork/steps/CallLLM/typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,5 @@ class CallLLMInputs(TypedDict, total=False):

class CallLLMOutputs(TypedDict):
    """Outputs of the CallLLM step; the three lists are parallel (index i of
    each refers to the same LLM call)."""

    # One response text per prompt, in prompt order.
    openai_responses: List[str]
    # Prompt-side token count for each call, aligned with openai_responses.
    request_tokens: List[int]
    # Completion-side token count for each call, aligned with openai_responses.
    response_tokens: List[int]
2 changes: 2 additions & 0 deletions patchwork/steps/LLM/LLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,7 @@ def run(self) -> dict:
prompts=prepare_prompt_outputs.get("prompts"),
openai_responses=call_llm_outputs.get("openai_responses"),
extracted_responses=extract_model_response_outputs.get("extracted_responses"),
request_tokens=call_llm_outputs.get("request_tokens"),
response_tokens=call_llm_outputs.get("response_tokens"),
)
)
2 changes: 2 additions & 0 deletions patchwork/steps/LLM/typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,7 @@ class LLMOutputs(TypedDict):
prompts: List[Dict]
# CallLLMOutputs
openai_responses: List[str]
request_tokens: List[int]
response_tokens: List[int]
# ExtractModelResponseOutputs
extracted_responses: List[Dict[str, str]]
2 changes: 2 additions & 0 deletions patchwork/steps/SimplifiedLLM/SimplifiedLLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,7 @@ def run(self) -> dict:
prompts=prepare_prompt_outputs.get("prompts"),
openai_responses=call_llm_outputs.get("openai_responses"),
extracted_responses=extract_model_response_outputs.get("extracted_responses"),
request_tokens=call_llm_outputs.get("request_tokens"),
response_tokens=call_llm_outputs.get("response_tokens"),
)
)
3 changes: 3 additions & 0 deletions patchwork/steps/SimplifiedLLM/typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class SimplifiedLLMInputs(__SimplifiedLLMInputsRequired, total=False):
str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key"])
]
json: Annotated[bool, StepTypeConfig(is_config=True)]
json_example_schema: Annotated[str, StepTypeConfig(is_config=True)]
# ExtractModelResponseInputs
response_partitions: Annotated[Dict[str, List[str]], StepTypeConfig(is_config=True)]

Expand All @@ -48,5 +49,7 @@ class SimplifiedLLMOutputs(TypedDict):
prompts: List[Dict]
# CallLLMOutputs
openai_responses: List[str]
request_tokens: List[int]
response_tokens: List[int]
# ExtractModelResponseOutputs
extracted_responses: List[Dict[str, str]]
2 changes: 2 additions & 0 deletions patchwork/steps/SimplifiedLLMOnce/SimplifiedLLMOnce.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,7 @@ def run(self) -> dict:
prompt=llm_output.get("prompts")[0],
openai_response=llm_output.get("openai_responses")[0],
extracted_response=llm_output.get("extracted_responses")[0],
request_tokens=llm_output.get("request_tokens")[0],
response_tokens=llm_output.get("response_tokens")[0],
)
)
2 changes: 2 additions & 0 deletions patchwork/steps/SimplifiedLLMOnce/typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,7 @@ class SimplifiedLLMOnceOutputs(TypedDict):
prompt: Dict
# CallLLMOutputs
openai_response: str
request_tokens: int
response_tokens: int
# ExtractModelResponseOutputs
extracted_response: Dict[str, str]
6 changes: 4 additions & 2 deletions patchwork/steps/SimplifiedLLMOncePB/SimplifiedLLMOncePB.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def run(self) -> dict:
)
llm_output = llm.run()

return dict(
return {
**llm_output.get("extracted_responses")[0],
)
"request_tokens": llm_output.get("request_tokens")[0],
"response_tokens": llm_output.get("response_tokens")[0],
}
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "patchwork-cli"
version = "0.0.49"
version = "0.0.50"
description = ""
authors = ["patched.codes"]
license = "AGPL"
Expand Down

0 comments on commit e1663fc

Please sign in to comment.