diff --git a/amplifier_module_context_simple/__init__.py b/amplifier_module_context_simple/__init__.py index ab57c14..ad8eabb 100644 --- a/amplifier_module_context_simple/__init__.py +++ b/amplifier_module_context_simple/__init__.py @@ -29,6 +29,14 @@ logger = logging.getLogger(__name__) +# Budget policy for extended context windows (e.g. Anthropic 1M beta). +# When context_window exceeds this threshold, use fraction-based budgeting +# instead of the standard (context_window - reserved_output) formula. +# This keeps sessions in the standard pricing zone while using the extended +# window as a safety net against overflow errors. +EXTENDED_CONTEXT_THRESHOLD = 200_000 +DEFAULT_EXTENDED_BUDGET_FRACTION = 0.15 + async def mount(coordinator: ModuleCoordinator, config: dict[str, Any] | None = None): """ @@ -47,6 +55,8 @@ async def mount(coordinator: ModuleCoordinator, config: dict[str, Any] | None = - compaction_notice_token_reserve: Tokens to reserve for notice (default: 800) - compaction_notice_verbosity: Notice detail level - "minimal", "normal", "verbose" (default: "normal") - compaction_notice_min_level: Only show notice if compaction level >= this (default: 1) + - extended_context_budget_fraction: Budget as fraction of context_window when + context_window > 200k (default: 0.15). Only applies to extended context windows. Returns: Optional cleanup function @@ -65,6 +75,11 @@ async def mount(coordinator: ModuleCoordinator, config: dict[str, Any] | None = ), compaction_notice_verbosity=config.get("compaction_notice_verbosity", "normal"), compaction_notice_min_level=config.get("compaction_notice_min_level", 1), + extended_context_budget_fraction=float( + config.get( + "extended_context_budget_fraction", DEFAULT_EXTENDED_BUDGET_FRACTION + ) + ), hooks=getattr(coordinator, "hooks", None), ) await coordinator.mount("context", context) @@ -118,6 +133,7 @@ def __init__( compaction_notice_token_reserve: int = 800, compaction_notice_verbosity: str = "normal", compaction_notice_min_level: int = 1, + extended_context_budget_fraction: float = DEFAULT_EXTENDED_BUDGET_FRACTION, hooks: Any = None, ): """ @@ -134,6 +150,9 @@ def __init__( compaction_notice_token_reserve: Tokens to reserve for notice compaction_notice_verbosity: Notice detail level ("minimal", "normal", "verbose") compaction_notice_min_level: Only show notice if compaction level >= this + extended_context_budget_fraction: Budget as fraction of context_window for + extended context windows (>200k). Default 0.15 keeps sessions in the + standard pricing zone. Only applies when context_window > 200k. hooks: Optional hooks instance for emitting observability events """ self.messages: list[dict[str, Any]] = [] @@ -147,6 +166,7 @@ def __init__( self.compaction_notice_token_reserve = compaction_notice_token_reserve self.compaction_notice_verbosity = compaction_notice_verbosity self.compaction_notice_min_level = compaction_notice_min_level + self.extended_context_budget_fraction = extended_context_budget_fraction self._hooks = hooks self._last_compaction_stats: dict[str, Any] | None = None self._system_prompt_factory: Callable[[], Awaitable[str]] | None = None @@ -1181,6 +1201,20 @@ def _calculate_budget(self, token_budget: int | None, provider: Any | None) -> i context_window = getattr(model_info, "context_window", None) max_output = getattr(model_info, "max_output_tokens", None) if context_window and max_output: + if context_window > EXTENDED_CONTEXT_THRESHOLD: + budget = ( + int( + context_window + * self.extended_context_budget_fraction + ) + - safety_margin + ) + logger.info( + f"Budget from extended context fraction: {budget:,} " + f"({self.extended_context_budget_fraction:.0%} of " + f"{context_window:,})" + ) + return budget reserved_output = int(max_output * output_reserve_fraction) budget = context_window - reserved_output - safety_margin logger.info( @@ -1197,6 +1231,17 @@ def _calculate_budget(self, token_budget: int | None, provider: Any | None) -> i max_output_tokens = defaults.get("max_output_tokens") if context_window and max_output_tokens: + if context_window > EXTENDED_CONTEXT_THRESHOLD: + budget = ( + int(context_window * self.extended_context_budget_fraction) + - safety_margin + ) + logger.info( + f"Budget from extended context fraction: {budget:,} " + f"({self.extended_context_budget_fraction:.0%} of " + f"{context_window:,})" + ) + return budget reserved_output = int(max_output_tokens * output_reserve_fraction) budget = context_window - reserved_output - safety_margin logger.info(