diff --git a/amplifier_module_context_simple/__init__.py b/amplifier_module_context_simple/__init__.py index 5b856b3..e1d3269 100644 --- a/amplifier_module_context_simple/__init__.py +++ b/amplifier_module_context_simple/__init__.py @@ -1183,6 +1183,16 @@ def _calculate_budget(self, token_budget: int | None, provider: Any | None) -> i if context_window and max_output: reserved_output = int(max_output * output_reserve_fraction) budget = context_window - reserved_output - safety_margin + # Apply provider's cost-optimal budget cap if set + budget_cap = ( + provider.get_info().defaults or {} + ).get("context_budget_cap") + if budget_cap is not None: + budget = min(budget, budget_cap - safety_margin) + logger.info( + f"Budget capped by provider hint: {budget:,} " + f"(cap={budget_cap:,})" + ) logger.info( f"Budget from provider model info: {budget:,} " f"(context={context_window:,}, reserved_output={reserved_output:,} " @@ -1199,6 +1209,14 @@ def _calculate_budget(self, token_budget: int | None, provider: Any | None) -> i if context_window and max_output_tokens: reserved_output = int(max_output_tokens * output_reserve_fraction) budget = context_window - reserved_output - safety_margin + # Apply provider's cost-optimal budget cap if set + budget_cap = defaults.get("context_budget_cap") + if budget_cap is not None: + budget = min(budget, budget_cap - safety_margin) + logger.info( + f"Budget capped by provider hint: {budget:,} " + f"(cap={budget_cap:,})" + ) logger.info( f"Budget from provider defaults: {budget:,} " f"(context={context_window:,}, reserved_output={reserved_output:,} "