Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions amplifier_module_context_simple/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@

logger = logging.getLogger(__name__)

# Budget policy for extended context windows (e.g. Anthropic 1M beta).
# When context_window exceeds this threshold, use fraction-based budgeting
# instead of the standard (context_window - reserved_output) formula.
# This keeps sessions in the standard pricing zone while using the extended
# window as a safety net against overflow errors.
EXTENDED_CONTEXT_THRESHOLD = 200_000
DEFAULT_EXTENDED_BUDGET_FRACTION = 0.15


async def mount(coordinator: ModuleCoordinator, config: dict[str, Any] | None = None):
"""
Expand All @@ -47,6 +55,8 @@ async def mount(coordinator: ModuleCoordinator, config: dict[str, Any] | None =
- compaction_notice_token_reserve: Tokens to reserve for notice (default: 800)
- compaction_notice_verbosity: Notice detail level - "minimal", "normal", "verbose" (default: "normal")
- compaction_notice_min_level: Only show notice if compaction level >= this (default: 1)
- extended_context_budget_fraction: Budget as fraction of context_window when
context_window > 200k (default: 0.15). Only applies to extended context windows.

Returns:
Optional cleanup function
Expand All @@ -65,6 +75,11 @@ async def mount(coordinator: ModuleCoordinator, config: dict[str, Any] | None =
),
compaction_notice_verbosity=config.get("compaction_notice_verbosity", "normal"),
compaction_notice_min_level=config.get("compaction_notice_min_level", 1),
extended_context_budget_fraction=float(
config.get(
"extended_context_budget_fraction", DEFAULT_EXTENDED_BUDGET_FRACTION
)
),
hooks=getattr(coordinator, "hooks", None),
)
await coordinator.mount("context", context)
Expand Down Expand Up @@ -118,6 +133,7 @@ def __init__(
compaction_notice_token_reserve: int = 800,
compaction_notice_verbosity: str = "normal",
compaction_notice_min_level: int = 1,
extended_context_budget_fraction: float = DEFAULT_EXTENDED_BUDGET_FRACTION,
hooks: Any = None,
):
"""
Expand All @@ -134,6 +150,9 @@ def __init__(
compaction_notice_token_reserve: Tokens to reserve for notice
compaction_notice_verbosity: Notice detail level ("minimal", "normal", "verbose")
compaction_notice_min_level: Only show notice if compaction level >= this
extended_context_budget_fraction: Budget as fraction of context_window for
extended context windows (>200k). Default 0.15 keeps sessions in the
standard pricing zone. Only applies when context_window > 200k.
hooks: Optional hooks instance for emitting observability events
"""
self.messages: list[dict[str, Any]] = []
Expand All @@ -147,6 +166,7 @@ def __init__(
self.compaction_notice_token_reserve = compaction_notice_token_reserve
self.compaction_notice_verbosity = compaction_notice_verbosity
self.compaction_notice_min_level = compaction_notice_min_level
self.extended_context_budget_fraction = extended_context_budget_fraction
self._hooks = hooks
self._last_compaction_stats: dict[str, Any] | None = None
self._system_prompt_factory: Callable[[], Awaitable[str]] | None = None
Expand Down Expand Up @@ -1181,6 +1201,20 @@ def _calculate_budget(self, token_budget: int | None, provider: Any | None) -> i
context_window = getattr(model_info, "context_window", None)
max_output = getattr(model_info, "max_output_tokens", None)
if context_window and max_output:
if context_window > EXTENDED_CONTEXT_THRESHOLD:
budget = (
int(
context_window
* self.extended_context_budget_fraction
)
- safety_margin
)
logger.info(
f"Budget from extended context fraction: {budget:,} "
f"({self.extended_context_budget_fraction:.0%} of "
f"{context_window:,})"
)
return budget
reserved_output = int(max_output * output_reserve_fraction)
budget = context_window - reserved_output - safety_margin
logger.info(
Expand All @@ -1197,6 +1231,17 @@ def _calculate_budget(self, token_budget: int | None, provider: Any | None) -> i
max_output_tokens = defaults.get("max_output_tokens")

if context_window and max_output_tokens:
if context_window > EXTENDED_CONTEXT_THRESHOLD:
budget = (
int(context_window * self.extended_context_budget_fraction)
- safety_margin
)
logger.info(
f"Budget from extended context fraction: {budget:,} "
f"({self.extended_context_budget_fraction:.0%} of "
f"{context_window:,})"
)
return budget
reserved_output = int(max_output_tokens * output_reserve_fraction)
budget = context_window - reserved_output - safety_margin
logger.info(
Expand Down