Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions libs/core/langchain_core/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,25 +189,26 @@ def _infer_arg_descriptions(
*,
parse_docstring: bool = False,
error_on_invalid_docstring: bool = False,
) -> tuple[str, dict]:
"""Infer argument descriptions from function docstring and annotations.

Args:
fn: The function to infer descriptions from.
parse_docstring: Whether to parse the docstring for descriptions.
error_on_invalid_docstring: Whether to raise error on invalid docstring.

Returns:
A tuple containing the function description and argument descriptions.
"""
) -> tuple[str | None, dict]:
"""Infer argument descriptions from function docstring and annotations."""
annotations = typing.get_type_hints(fn, include_extras=True)
description: str | None
arg_descriptions: dict

if parse_docstring:
description, arg_descriptions = _parse_python_function_docstring(
fn, annotations, error_on_invalid_docstring=error_on_invalid_docstring
)
else:
description = inspect.getdoc(fn) or ""
description = inspect.getdoc(fn)
arg_descriptions = {}

if inspect.isclass(fn) and description:
for parent in fn.__bases__:
if inspect.getdoc(parent) == description:
description = None
break

if parse_docstring:
_validate_docstring_args_against_annotations(arg_descriptions, annotations)
for arg, arg_type in annotations.items():
Expand Down
29 changes: 16 additions & 13 deletions libs/core/langchain_core/tools/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import inspect
import textwrap
from collections.abc import Awaitable, Callable
from inspect import signature
Expand Down Expand Up @@ -129,10 +130,10 @@ def from_function(
coroutine: Callable[..., Awaitable[Any]] | None = None,
name: str | None = None,
description: str | None = None,
return_direct: bool = False, # noqa: FBT001,FBT002
args_schema: ArgsSchema | None = None,
infer_schema: bool = True, # noqa: FBT001,FBT002
*,
return_direct: bool = False,
args_schema: ArgsSchema | None = None,
infer_schema: bool = True,
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = False,
Expand Down Expand Up @@ -189,14 +190,14 @@ def add(a: int, b: int) -> int:
raise ValueError(msg)
name = name or source_function.__name__
if args_schema is None and infer_schema:
# schema name is appended within function
args_schema = create_schema_from_function(
name,
source_function,
parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring,
filter_args=_filter_schema_args(source_function),
)

description_ = description
if description is None and not parse_docstring:
description_ = source_function.__doc__ or None
Expand All @@ -213,20 +214,22 @@ def add(a: int, b: int) -> int:
elif isinstance(args_schema, dict):
description_ = args_schema.get("description")
else:
msg = (
"Invalid args_schema: expected BaseModel or dict, "
f"got {args_schema}"
)
msg = f"""Invalid args_schema: expected BaseModel or
dict, got {args_schema}"""
raise TypeError(msg)

if description_ is None:
msg = "Function must have a docstring if description not provided."
raise ValueError(msg)
if inspect.isclass(source_function) and is_basemodel_subclass(
source_function
):
description_ = ""
else:
msg = "Function must have a docstring if description not provided."
raise ValueError(msg)

if description is None:
# Only apply if using the function's docstring
description_ = textwrap.dedent(description_).strip()

# Description example:
# search_api(query: str) - Searches the API for the query.
description_ = f"{description_.strip()}"
return cls(
name=name,
Expand Down
15 changes: 15 additions & 0 deletions libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2757,3 +2757,18 @@ def test_tool(
"type": "array",
}
}


def test_child_tool_does_not_inherit_docstring() -> None:
"""Test that a tool subclass does not inherit its parent's docstring."""

class MyTool(BaseModel):
"""Parent Tool."""

foo: str

@tool
class ChildTool(MyTool):
bar: str

assert ChildTool.description == "" # type: ignore[attr-defined]
Loading