diff --git a/sentry_sdk/integrations/langchain.py b/sentry_sdk/integrations/langchain.py index 0b8bbd8049..ee9d09da51 100644 --- a/sentry_sdk/integrations/langchain.py +++ b/sentry_sdk/integrations/langchain.py @@ -23,6 +23,7 @@ from langchain_core.callbacks import ( manager, BaseCallbackHandler, + BaseCallbackManager, Callbacks, ) from langchain_core.agents import AgentAction, AgentFinish @@ -434,12 +435,47 @@ def new_configure( **kwargs, ) - callbacks_list = local_callbacks or [] + # Lambda for lazy initialization of the SentryLangchainCallback + sentry_handler_factory = lambda: SentryLangchainCallback( + integration.max_spans, + integration.include_prompts, + integration.tiktoken_encoding_name, + ) + + local_callbacks = local_callbacks or [] + + # Handle each possible type of local_callbacks. For each type, we + # extract the list of callbacks to check for SentryLangchainCallback, + # and define a function that would add the SentryLangchainCallback + # to the existing callbacks list. + if isinstance(local_callbacks, BaseCallbackManager): + callbacks_list = local_callbacks.handlers + manager = local_callbacks + + # For BaseCallbackManager, we want to copy the manager and add the + # SentryLangchainCallback to the copy. + def local_callbacks_with_sentry(): + # type: () -> Union[BaseCallbackManager, list[BaseCallbackHandler]] + new_manager = manager.copy() + new_manager.handlers = [*new_manager.handlers, sentry_handler_factory()] + return new_manager + + elif isinstance(local_callbacks, BaseCallbackHandler): + callbacks_list = [local_callbacks] - if isinstance(callbacks_list, BaseCallbackHandler): - callbacks_list = [callbacks_list] - elif not isinstance(callbacks_list, list): - logger.debug("Unknown callback type: %s", callbacks_list) + def local_callbacks_with_sentry(): + # type: () -> Union[BaseCallbackManager, list[BaseCallbackHandler]] + return [*callbacks_list, sentry_handler_factory()] + + elif isinstance(local_callbacks, list): + callbacks_list = local_callbacks + + def local_callbacks_with_sentry(): + # type: () -> Union[BaseCallbackManager, list[BaseCallbackHandler]] + return [*callbacks_list, sentry_handler_factory()] + + else: + logger.debug("Unknown callback type: %s", local_callbacks) # Just proceed with original function call return f( callback_manager_cls, @@ -457,20 +493,12 @@ def new_configure( isinstance(cb, SentryLangchainCallback) for cb in itertools.chain(callbacks_list, inheritable_callbacks_list) ): - # Avoid mutating the existing callbacks list - callbacks_list = [ - *callbacks_list, - SentryLangchainCallback( - integration.max_spans, - integration.include_prompts, - integration.tiktoken_encoding_name, - ), - ] + local_callbacks = local_callbacks_with_sentry() return f( callback_manager_cls, inheritable_callbacks, - callbacks_list, + local_callbacks, *args, **kwargs, ) diff --git a/tests/integrations/langchain/test_langchain.py b/tests/integrations/langchain/test_langchain.py index 8ace6d4821..71fb1b960e 100644 --- a/tests/integrations/langchain/test_langchain.py +++ b/tests/integrations/langchain/test_langchain.py @@ -1,4 +1,5 @@ from typing import List, Optional, Any, Iterator +from unittest import mock from unittest.mock import Mock import pytest @@ -12,7 +13,7 @@ # Langchain < 0.2 from langchain_community.chat_models import ChatOpenAI -from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.callbacks import BaseCallbackManager, CallbackManagerForLLMRun from langchain_core.messages import BaseMessage, AIMessageChunk from langchain_core.outputs import ChatGenerationChunk, ChatResult from langchain_core.runnables import RunnableConfig @@ -428,3 +429,131 @@ def test_span_map_is_instance_variable(): assert ( callback1.span_map is not callback2.span_map ), "span_map should be an instance variable, not shared between instances" + + +def test_langchain_callback_manager(sentry_init): + sentry_init( + integrations=[LangchainIntegration()], + traces_sample_rate=1.0, + ) + local_manager = BaseCallbackManager(handlers=[]) + + with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module: + mock_configure = mock_manager_module._configure + + # Explicitly re-run setup_once, so that mock_manager_module._configure gets patched + LangchainIntegration.setup_once() + + callback_manager_cls = Mock() + + mock_manager_module._configure( + callback_manager_cls, local_callbacks=local_manager + ) + + assert mock_configure.call_count == 1 + + call_args = mock_configure.call_args + assert call_args.args[0] is callback_manager_cls + + passed_manager = call_args.args[2] + assert passed_manager is not local_manager + assert local_manager.handlers == [] + + [handler] = passed_manager.handlers + assert isinstance(handler, SentryLangchainCallback) + + +def test_langchain_callback_manager_with_sentry_callback(sentry_init): + sentry_init( + integrations=[LangchainIntegration()], + traces_sample_rate=1.0, + ) + sentry_callback = SentryLangchainCallback(0, False) + local_manager = BaseCallbackManager(handlers=[sentry_callback]) + + with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module: + mock_configure = mock_manager_module._configure + + # Explicitly re-run setup_once, so that mock_manager_module._configure gets patched + LangchainIntegration.setup_once() + + callback_manager_cls = Mock() + + mock_manager_module._configure( + callback_manager_cls, local_callbacks=local_manager + ) + + assert mock_configure.call_count == 1 + + call_args = mock_configure.call_args + assert call_args.args[0] is callback_manager_cls + + passed_manager = call_args.args[2] + assert passed_manager is local_manager + + [handler] = passed_manager.handlers + assert handler is sentry_callback + + +def test_langchain_callback_list(sentry_init): + sentry_init( + integrations=[LangchainIntegration()], + traces_sample_rate=1.0, + ) + local_callbacks = [] + + with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module: + mock_configure = mock_manager_module._configure + + # Explicitly re-run setup_once, so that mock_manager_module._configure gets patched + LangchainIntegration.setup_once() + + callback_manager_cls = Mock() + + mock_manager_module._configure( + callback_manager_cls, local_callbacks=local_callbacks + ) + + assert mock_configure.call_count == 1 + + call_args = mock_configure.call_args + assert call_args.args[0] is callback_manager_cls + + passed_callbacks = call_args.args[2] + assert passed_callbacks is not local_callbacks + assert local_callbacks == [] + + [handler] = passed_callbacks + assert isinstance(handler, SentryLangchainCallback) + + +def test_langchain_callback_list_existing_callback(sentry_init): + sentry_init( + integrations=[LangchainIntegration()], + traces_sample_rate=1.0, + ) + sentry_callback = SentryLangchainCallback(0, False) + local_callbacks = [sentry_callback] + + with mock.patch("sentry_sdk.integrations.langchain.manager") as mock_manager_module: + mock_configure = mock_manager_module._configure + + # Explicitly re-run setup_once, so that mock_manager_module._configure gets patched + LangchainIntegration.setup_once() + + callback_manager_cls = Mock() + + mock_manager_module._configure( + callback_manager_cls, local_callbacks=local_callbacks + ) + + assert mock_configure.call_count == 1 + + call_args = mock_configure.call_args + assert call_args.args[0] is callback_manager_cls + + passed_callbacks = call_args.args[2] + assert passed_callbacks is local_callbacks + + [handler] = passed_callbacks + assert handler is sentry_callback