diff --git a/pyproject.toml b/pyproject.toml index e17f3628c9..7ff5408bbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ description = "Typer, build great CLIs. Easy to code. Based on Python type hints authors = [ {name = "Sebastián Ramírez", email = "tiangolo@gmail.com"}, ] -requires-python = ">=3.7" +requires-python = ">=3.7" # need 3.8 for typing-extensions >=4.8.0 classifiers = [ "Intended Audience :: Information Technology", "Intended Audience :: System Administrators", @@ -34,7 +34,7 @@ classifiers = [ ] dependencies = [ "click >= 8.0.0", - "typing-extensions >= 3.7.4.3", + "typing-extensions >= 4.8.0", ] readme = "README.md" [project.urls] diff --git a/tests/test_ambiguous_params.py b/tests/test_ambiguous_params.py index 0693c8e9aa..5ca8971cab 100644 --- a/tests/test_ambiguous_params.py +++ b/tests/test_ambiguous_params.py @@ -8,7 +8,7 @@ MultipleTyperAnnotationsError, _split_annotation_from_typer_annotations, ) -from typing_extensions import Annotated +from typing_extensions import Annotated, Doc runner = CliRunner() @@ -17,10 +17,24 @@ def test_split_annotations_from_typer_annotations_simple(): # Simple sanity check that this utility works. If this isn't working on a given # python version, then no other tests for Annotated will work. given = Annotated[str, typer.Argument()] - base, typer_annotations = _split_annotation_from_typer_annotations(given) + base, typer_annotations, other_annotations = ( + _split_annotation_from_typer_annotations(given) + ) assert base is str # No equality check on the param types. Checking the length is sufficient. assert len(typer_annotations) == 1 + assert len(other_annotations) == 0 + + +def test_split_other_annotations_from_typer_annotations(): + given = Annotated[str, typer.Argument(), Doc("doc help")] + base, typer_annotations, other_annotations = ( + _split_annotation_from_typer_annotations(given) + ) + assert base is str + assert len(typer_annotations) == 1 + assert len(other_annotations) == 1 + assert isinstance(other_annotations[0], Doc) def test_forbid_default_value_in_annotated_argument(): diff --git a/tests/test_parameter_help.py b/tests/test_parameter_help.py new file mode 100644 index 0000000000..e2e2f113c0 --- /dev/null +++ b/tests/test_parameter_help.py @@ -0,0 +1,48 @@ +from typing import Annotated + +import pytest +import typer +import typer.completion +from typer import Argument, Option +from typer.testing import CliRunner +from typing_extensions import Doc + + +@pytest.mark.parametrize( + "doc,parameter,expected", + [ + (Doc("doc only help"), None, "doc only help"), + (None, Argument(help="argument only help"), "argument only help"), + ( + Doc("doc help should appear"), + Argument(), + "doc help should appear", + ), + ( + Doc("this help should not appear"), + Argument(help="argument help has priority"), + "argument help has priority", + ), + (None, Option(help="option only help"), "option only help"), + ( + Doc("this help should not appear"), + Option(help="option help has priority"), + "option help has priority", + ), + ( + Doc("doc help should appear"), + Option(), + "doc help should appear", + ), + ], +) +def test_doc_help(doc, parameter, expected): + app = typer.Typer() + + @app.command() + def main(arg: Annotated[str, doc, parameter]): + print(f"Hello {arg}") + + runner = CliRunner() + result = runner.invoke(app, ["--help"]) + assert expected in result.stdout diff --git a/typer/main.py b/typer/main.py index 36737e49ef..aa6674275c 100644 --- a/typer/main.py +++ b/typer/main.py @@ -15,7 +15,7 @@ from uuid import UUID import click -from typing_extensions import get_args, get_origin +from typing_extensions import Doc, get_args, get_origin # type: ignore from ._typing import is_union from .completion import get_completion_inspect_parameters @@ -46,7 +46,7 @@ Required, TyperInfo, ) -from .utils import get_params_from_function +from .utils import MultipleDocAnnotationsError, get_params_from_function try: import rich @@ -800,12 +800,29 @@ def lenient_issubclass( return isinstance(cls, type) and issubclass(cls, class_or_tuple) +def _set_doc_help(param: ParamMeta, parameter_info: ParameterInfo) -> None: + if not param.other_annotations: + return + doc_annotations = [ + annotation + for annotation in param.other_annotations + if isinstance(annotation, Doc) + ] + if len(doc_annotations) > 1: + raise MultipleDocAnnotationsError(param.name) + if len(doc_annotations) == 1: + doc_help = doc_annotations[0].documentation if doc_annotations else None + if not getattr(parameter_info, "help", None): + parameter_info.help = doc_help + + def get_click_param( param: ParamMeta, ) -> Tuple[Union[click.Argument, click.Option], Any]: # First, find out what will be: # * ParamInfo (ArgumentInfo or OptionInfo) # * default_value + # * help message # * required default_value = None required = False @@ -821,6 +838,7 @@ def get_click_param( else: default_value = param.default parameter_info = OptionInfo() + _set_doc_help(param, parameter_info) annotation: Any if param.annotation is not param.empty: annotation = param.annotation diff --git a/typer/models.py b/typer/models.py index 544e504761..c8155c06e9 100644 --- a/typer/models.py +++ b/typer/models.py @@ -514,10 +514,12 @@ def __init__( name: str, default: Any = inspect.Parameter.empty, annotation: Any = inspect.Parameter.empty, + other_annotations: Optional[List[Any]] = None, ) -> None: self.name = name self.default = default self.annotation = annotation + self.other_annotations = other_annotations class DeveloperExceptionConfig: diff --git a/typer/utils.py b/typer/utils.py index 93c407447e..83d8841a2a 100644 --- a/typer/utils.py +++ b/typer/utils.py @@ -63,6 +63,19 @@ def __str__(self) -> str: return msg +class MultipleDocAnnotationsError(Exception): + argument_name: str + + def __init__(self, argument_name: str): + self.argument_name = argument_name + + def __str__(self) -> str: + return ( + "Cannot specify multiple `Annotated` Doc arguments" + f" for {self.argument_name!r}" + ) + + class MultipleTyperAnnotationsError(Exception): argument_name: str @@ -94,15 +107,25 @@ def __str__(self) -> str: def _split_annotation_from_typer_annotations( base_annotation: Type[Any], -) -> Tuple[Type[Any], List[ParameterInfo]]: +) -> Tuple[Type[Any], List[ParameterInfo], List[Any]]: if get_origin(base_annotation) is not Annotated: - return base_annotation, [] - base_annotation, *maybe_typer_annotations = get_args(base_annotation) - return base_annotation, [ + return base_annotation, [], [] + base_annotation, *other_annotations = get_args(base_annotation) + typer_annotations = [ annotation - for annotation in maybe_typer_annotations + for annotation in other_annotations if isinstance(annotation, ParameterInfo) ] + other_annotations = [ + annotation + for annotation in other_annotations + if not isinstance(annotation, ParameterInfo) + ] + return ( + base_annotation, + typer_annotations, + other_annotations, + ) def get_params_from_function(func: Callable[..., Any]) -> Dict[str, ParamMeta]: @@ -114,8 +137,10 @@ def get_params_from_function(func: Callable[..., Any]) -> Dict[str, ParamMeta]: type_hints = get_type_hints(func) params = {} for param in signature.parameters.values(): - annotation, typer_annotations = _split_annotation_from_typer_annotations( - param.annotation, + annotation, typer_annotations, other_annotations = ( + _split_annotation_from_typer_annotations( + param.annotation, + ) ) if len(typer_annotations) > 1: raise MultipleTyperAnnotationsError(param.name) @@ -186,6 +211,9 @@ def get_params_from_function(func: Callable[..., Any]) -> Dict[str, ParamMeta]: default = parameter_info params[param.name] = ParamMeta( - name=param.name, default=default, annotation=annotation + name=param.name, + default=default, + annotation=annotation, + other_annotations=other_annotations, ) return params