diff --git a/libs/superagent/app/memory/buffer_memory.py b/libs/superagent/app/memory/buffer_memory.py
index 96c2467fa..5ba2fad18 100644
--- a/libs/superagent/app/memory/buffer_memory.py
+++ b/libs/superagent/app/memory/buffer_memory.py
@@ -7,7 +7,23 @@ from app.memory.message import BaseMessage
 
 
 DEFAULT_TOKEN_LIMIT_RATIO = 0.75
-DEFAULT_TOKEN_LIMIT = 3000
+DEFAULT_TOKEN_LIMIT = 3072
+
+
+def get_context_window(model: str) -> int:
+    max_input_tokens = model_cost.get(model, {}).get("max_input_tokens")
+
+    # Some model names appear without a provider prefix but refer to the
+    # same model, e.g. claude-3-haiku-20240307 and
+    # anthropic/claude-3-haiku-20240307; retry the lookup without the prefix.
+    if not max_input_tokens:
+        model_parts = model.split("/", 1)
+        if len(model_parts) > 1:
+            model_without_prefix = model_parts[1]
+            max_input_tokens = model_cost.get(model_without_prefix, {}).get(
+                "max_input_tokens", DEFAULT_TOKEN_LIMIT
+            )
+    return max_input_tokens or DEFAULT_TOKEN_LIMIT
 
 
 class BufferMemory(BaseMemory):
@@ -21,8 +37,9 @@ def __init__(
         self.memory_store = memory_store
         self.tokenizer_fn = tokenizer_fn
         self.model = model
-        context_window = model_cost.get(self.model, {}).get("max_input_tokens")
-        self.context_window = max_tokens or context_window * DEFAULT_TOKEN_LIMIT_RATIO
+        self.context_window = (
+            max_tokens or get_context_window(model=model) * DEFAULT_TOKEN_LIMIT_RATIO
+        )
 
     def add_message(self, message: BaseMessage) -> None:
         self.memory_store.add_message(message)
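
For reference, a minimal usage sketch of the new helper (a hypothetical driver, not part of the patch; exact values depend on the installed litellm version's model_cost table):

    from app.memory.buffer_memory import get_context_window

    # Direct hit in litellm's model_cost table.
    print(get_context_window("claude-3-haiku-20240307"))

    # Provider-prefixed alias: if the prefixed key is absent from the table,
    # the helper retries with the prefix stripped and resolves the same entry.
    print(get_context_window("anthropic/claude-3-haiku-20240307"))

    # Unknown model: falls back to DEFAULT_TOKEN_LIMIT (3072), so the
    # context-window arithmetic in BufferMemory.__init__ never sees None.
    print(get_context_window("vendor/unknown-model"))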