
Commit 2cb902d

fix(anthropic): clean up cache_control in middleware to prevent fallback errors
Fixes an issue where AnthropicPromptCachingMiddleware's cache_control parameter persisted in model_settings when ModelFallbackMiddleware switched to non-Anthropic models (OpenAI, Google), causing a TypeError. The fix uses try/finally blocks to ensure cache_control is always removed from model_settings after handler execution, regardless of success or failure. This prevents the Anthropic-specific parameter from being passed to fallback models.

Changes:
- Added cleanup logic in wrap_model_call() and awrap_model_call()
- Updated existing tests to verify the cleanup behavior
- Added comprehensive tests for both success and error cases

Fixes #33709
1 parent 94d5271 commit 2cb902d
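
To make the failure mode concrete, here is a minimal, self-contained sketch (an editor's illustration with hypothetical function names, not code from this commit). It shows how an Anthropic-only key left in a shared model_settings dict breaks a later call to a non-Anthropic model:

from typing import Any


def apply_prompt_caching(model_settings: dict[str, Any]) -> None:
    # Pre-fix behavior: set the Anthropic-only key on the shared
    # settings dict and never remove it afterwards.
    model_settings["cache_control"] = {"type": "ephemeral", "ttl": "5m"}


def call_openai_model(**model_settings: Any) -> str:
    # Stand-in for a non-Anthropic model: an unknown kwarg raises
    # TypeError, the error reported in issue #33709.
    if "cache_control" in model_settings:
        msg = "got an unexpected keyword argument 'cache_control'"
        raise TypeError(msg)
    return "fallback response"


settings: dict[str, Any] = {}
apply_prompt_caching(settings)
# Suppose the Anthropic call fails and a fallback middleware retries with
# another provider, reusing the same settings dict: the leaked key breaks it.
try:
    call_openai_model(**settings)
except TypeError as exc:
    print(f"fallback failed: {exc}")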

File tree: 2 files changed, +236 -17 lines changed


libs/partners/anthropic/langchain_anthropic/middleware/prompt_caching.py

Lines changed: 12 additions & 2 deletions
@@ -120,7 +120,12 @@ def wrap_model_call(
             return handler(request)
 
         self._apply_cache_control(request)
-        return handler(request)
+        try:
+            return handler(request)
+        finally:
+            # Clean up cache_control to prevent it from being passed to fallback models
+            # that don't support this Anthropic-specific parameter
+            request.model_settings.pop("cache_control", None)
 
     async def awrap_model_call(
         self,
@@ -140,4 +145,9 @@ async def awrap_model_call(
             return await handler(request)
 
         self._apply_cache_control(request)
-        return await handler(request)
+        try:
+            return await handler(request)
+        finally:
+            # Clean up cache_control to prevent it from being passed to fallback models
+            # that don't support this Anthropic-specific parameter
+            request.model_settings.pop("cache_control", None)
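
With the try/finally cleanup in place, the caching middleware can sit in front of a fallback middleware without leaking the setting. A hedged usage sketch follows; the create_agent API, import paths, and model identifiers are assumptions for illustration, not taken from this commit:

from langchain.agents import create_agent  # assumed import path
from langchain.agents.middleware import ModelFallbackMiddleware  # assumed import path
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware

agent = create_agent(
    model="anthropic:claude-sonnet-4-5",  # assumed model identifier
    middleware=[
        # Adds cache_control to model_settings while the handler runs,
        # then pops it in a finally block (the change in this commit).
        AnthropicPromptCachingMiddleware(),
        # On a primary-model error, retries with a non-Anthropic model;
        # after the fix it no longer sees a leftover cache_control key.
        ModelFallbackMiddleware("openai:gpt-4o"),  # assumed fallback model
    ],
)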

libs/partners/anthropic/tests/unit_tests/middleware/test_prompt_caching.py

Lines changed: 224 additions & 15 deletions
@@ -82,14 +82,18 @@ def test_anthropic_prompt_caching_middleware_initialization() -> None:
         model_settings={},
     )
 
+    # Track the state during handler execution
+    settings_during_call = {}
+
     def mock_handler(req: ModelRequest) -> ModelResponse:
+        settings_during_call.update(req.model_settings)
         return ModelResponse(result=[AIMessage(content="mock response")])
 
     middleware.wrap_model_call(fake_request, mock_handler)
-    # Check that model_settings were passed through via the request
-    assert fake_request.model_settings == {
-        "cache_control": {"type": "ephemeral", "ttl": "5m"}
-    }
+    # Check that model_settings were passed through during handler execution
+    assert settings_during_call == {"cache_control": {"type": "ephemeral", "ttl": "5m"}}
+    # Verify cleanup after handler completes
+    assert fake_request.model_settings == {}
 
 
 def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
@@ -162,15 +166,19 @@ async def test_anthropic_prompt_caching_middleware_async() -> None:
         model_settings={},
     )
 
+    # Track the state during handler execution
+    settings_during_call = {}
+
     async def mock_handler(req: ModelRequest) -> ModelResponse:
+        settings_during_call.update(req.model_settings)
         return ModelResponse(result=[AIMessage(content="mock response")])
 
     result = await middleware.awrap_model_call(fake_request, mock_handler)
     assert isinstance(result, ModelResponse)
-    # Check that model_settings were passed through via the request
-    assert fake_request.model_settings == {
-        "cache_control": {"type": "ephemeral", "ttl": "1h"}
-    }
+    # Check that model_settings were passed through during handler execution
+    assert settings_during_call == {"cache_control": {"type": "ephemeral", "ttl": "1h"}}
+    # Verify cleanup after handler completes
+    assert fake_request.model_settings == {}
 
 
 async def test_anthropic_prompt_caching_middleware_async_unsupported_model() -> None:
@@ -268,15 +276,19 @@ async def test_anthropic_prompt_caching_middleware_async_with_system_prompt() -> None:
         model_settings={},
     )
 
+    # Track the state during handler execution
+    settings_during_call = {}
+
     async def mock_handler(req: ModelRequest) -> ModelResponse:
+        settings_during_call.update(req.model_settings)
         return ModelResponse(result=[AIMessage(content="mock response")])
 
     result = await middleware.awrap_model_call(fake_request, mock_handler)
     assert isinstance(result, ModelResponse)
     # Cache control should be added when system prompt pushes count to minimum
-    assert fake_request.model_settings == {
-        "cache_control": {"type": "ephemeral", "ttl": "1h"}
-    }
+    assert settings_during_call == {"cache_control": {"type": "ephemeral", "ttl": "1h"}}
+    # Verify cleanup after handler completes
+    assert fake_request.model_settings == {}
 
 
 async def test_anthropic_prompt_caching_middleware_async_default_values() -> None:
@@ -300,12 +312,209 @@ async def test_anthropic_prompt_caching_middleware_async_default_values() -> None:
         model_settings={},
     )
 
+    # Track the state during handler execution
+    settings_during_call = {}
+
+    async def mock_handler(req: ModelRequest) -> ModelResponse:
+        settings_during_call.update(req.model_settings)
+        return ModelResponse(result=[AIMessage(content="mock response")])
+
+    result = await middleware.awrap_model_call(fake_request, mock_handler)
+    assert isinstance(result, ModelResponse)
+    # Check that model_settings were added with default values during handler execution
+    assert settings_during_call == {"cache_control": {"type": "ephemeral", "ttl": "5m"}}
+    # Verify cleanup after handler completes
+    assert fake_request.model_settings == {}
+
+
+def test_cache_control_cleanup_on_success() -> None:
+    """Test that cache_control is cleaned up after successful handler execution.
+
+    This test verifies the fix for issue #33709 where cache_control was persisting
+    in model_settings and breaking fallback middleware with non-Anthropic models.
+    """
+    middleware = AnthropicPromptCachingMiddleware()
+    mock_chat_anthropic = MagicMock(spec=ChatAnthropic)
+
+    fake_request = ModelRequest(
+        model=mock_chat_anthropic,
+        messages=[HumanMessage("Hello")],
+        system_prompt=None,
+        tool_choice=None,
+        tools=[],
+        response_format=None,
+        state={"messages": [HumanMessage("Hello")]},
+        runtime=cast(Runtime, object()),
+        model_settings={},
+    )
+
+    # Track the state of model_settings during handler execution
+    settings_during_call = {}
+
+    def mock_handler(req: ModelRequest) -> ModelResponse:
+        # Capture model_settings during handler execution
+        settings_during_call.update(req.model_settings)
+        return ModelResponse(result=[AIMessage(content="mock response")])
+
+    result = middleware.wrap_model_call(fake_request, mock_handler)
+
+    # Verify cache_control was present during handler execution
+    assert "cache_control" in settings_during_call
+    assert settings_during_call["cache_control"] == {"type": "ephemeral", "ttl": "5m"}
+
+    # Verify cache_control is cleaned up after handler returns
+    assert "cache_control" not in fake_request.model_settings
+    assert fake_request.model_settings == {}
+    assert isinstance(result, ModelResponse)
+
+
+def test_cache_control_cleanup_on_error() -> None:
+    """Test that cache_control is cleaned up even when handler raises exception.
+
+    This ensures cleanup happens in all cases, preventing cache_control from
+    persisting when fallback middleware tries alternative models.
+    """
+    middleware = AnthropicPromptCachingMiddleware()
+    mock_chat_anthropic = MagicMock(spec=ChatAnthropic)
+
+    fake_request = ModelRequest(
+        model=mock_chat_anthropic,
+        messages=[HumanMessage("Hello")],
+        system_prompt=None,
+        tool_choice=None,
+        tools=[],
+        response_format=None,
+        state={"messages": [HumanMessage("Hello")]},
+        runtime=cast(Runtime, object()),
+        model_settings={},
+    )
+
+    # Track the state of model_settings during handler execution
+    settings_during_call = {}
+
+    def failing_handler(req: ModelRequest) -> ModelResponse:
+        # Capture model_settings before raising error
+        settings_during_call.update(req.model_settings)
+        msg = "Simulated API error"
+        raise RuntimeError(msg)
+
+    # Handler should raise the exception
+    with pytest.raises(RuntimeError, match="Simulated API error"):
+        middleware.wrap_model_call(fake_request, failing_handler)
+
+    # Verify cache_control was present during handler execution
+    assert "cache_control" in settings_during_call
+
+    # Verify cache_control is cleaned up even after exception
+    assert "cache_control" not in fake_request.model_settings
+    assert fake_request.model_settings == {}
+
+
+async def test_cache_control_cleanup_on_success_async() -> None:
+    """Test async cleanup of cache_control after successful handler execution."""
+    middleware = AnthropicPromptCachingMiddleware()
+    mock_chat_anthropic = MagicMock(spec=ChatAnthropic)
+
+    fake_request = ModelRequest(
+        model=mock_chat_anthropic,
+        messages=[HumanMessage("Hello")],
+        system_prompt=None,
+        tool_choice=None,
+        tools=[],
+        response_format=None,
+        state={"messages": [HumanMessage("Hello")]},
+        runtime=cast(Runtime, object()),
+        model_settings={},
+    )
+
+    # Track the state of model_settings during handler execution
+    settings_during_call = {}
+
     async def mock_handler(req: ModelRequest) -> ModelResponse:
+        # Capture model_settings during handler execution
+        settings_during_call.update(req.model_settings)
         return ModelResponse(result=[AIMessage(content="mock response")])
 
     result = await middleware.awrap_model_call(fake_request, mock_handler)
+
+    # Verify cache_control was present during handler execution
+    assert "cache_control" in settings_during_call
+    assert settings_during_call["cache_control"] == {"type": "ephemeral", "ttl": "5m"}
+
+    # Verify cache_control is cleaned up after handler returns
+    assert "cache_control" not in fake_request.model_settings
+    assert fake_request.model_settings == {}
+    assert isinstance(result, ModelResponse)
+
+
+async def test_cache_control_cleanup_on_error_async() -> None:
+    """Test async cleanup of cache_control even when handler raises exception."""
+    middleware = AnthropicPromptCachingMiddleware()
+    mock_chat_anthropic = MagicMock(spec=ChatAnthropic)
+
+    fake_request = ModelRequest(
+        model=mock_chat_anthropic,
+        messages=[HumanMessage("Hello")],
+        system_prompt=None,
+        tool_choice=None,
+        tools=[],
+        response_format=None,
+        state={"messages": [HumanMessage("Hello")]},
+        runtime=cast(Runtime, object()),
+        model_settings={},
+    )
+
+    # Track the state of model_settings during handler execution
+    settings_during_call = {}
+
+    async def failing_handler(req: ModelRequest) -> ModelResponse:
+        # Capture model_settings before raising error
+        settings_during_call.update(req.model_settings)
+        msg = "Simulated async API error"
+        raise RuntimeError(msg)
+
+    # Handler should raise the exception
+    with pytest.raises(RuntimeError, match="Simulated async API error"):
+        await middleware.awrap_model_call(fake_request, failing_handler)
+
+    # Verify cache_control was present during handler execution
+    assert "cache_control" in settings_during_call
+
+    # Verify cache_control is cleaned up even after exception
+    assert "cache_control" not in fake_request.model_settings
+    assert fake_request.model_settings == {}
+
+
+def test_no_cleanup_when_caching_not_applied() -> None:
+    """Test that cleanup doesn't interfere when caching is not applied.
+
+    When using an unsupported model or below min_messages_to_cache,
+    cache_control should never be added or cleaned up.
+    """
+    middleware = AnthropicPromptCachingMiddleware(
+        unsupported_model_behavior="ignore",
+        min_messages_to_cache=10,
+    )
+
+    fake_request = ModelRequest(
+        model=FakeToolCallingModel(),  # Unsupported model
+        messages=[HumanMessage("Hello")],
+        system_prompt=None,
+        tool_choice=None,
+        tools=[],
+        response_format=None,
+        state={"messages": [HumanMessage("Hello")]},
+        runtime=cast(Runtime, object()),
+        model_settings={},
+    )
+
+    def mock_handler(req: ModelRequest) -> ModelResponse:
+        # Verify cache_control was never added
+        assert "cache_control" not in req.model_settings
+        return ModelResponse(result=[AIMessage(content="mock response")])
+
+    result = middleware.wrap_model_call(fake_request, mock_handler)
+
+    # Verify model_settings remain empty throughout
+    assert fake_request.model_settings == {}
     assert isinstance(result, ModelResponse)
-    # Check that model_settings were added with default values
-    assert fake_request.model_settings == {
-        "cache_control": {"type": "ephemeral", "ttl": "5m"}
-    }
