diff --git a/agentlightning/config.py b/agentlightning/config.py index 65eebb60..1a0276fb 100644 --- a/agentlightning/config.py +++ b/agentlightning/config.py @@ -37,6 +37,9 @@ _C4 = TypeVar("_C4", bound=CliConfigurable) +_DEFAULT_SENTINEL = object() + + # Custom type for CLI arguments that can be string or None def nullable_str(value: str) -> str | None: """Converts specific string values (case-insensitive) to None, otherwise returns the string.""" @@ -174,6 +177,7 @@ def _add_argument_for_parameter( param_obj: inspect.Parameter, dest_name: str, resolved_param_annotation: Any = None, + provided_default: Any = _DEFAULT_SENTINEL, ) -> None: """Configures and adds a single CLI argument for an __init__ parameter.""" if resolved_param_annotation is None: @@ -189,16 +193,24 @@ def _add_argument_for_parameter( has_init_default = param_obj.default is not inspect.Parameter.empty init_default_value = param_obj.default if has_init_default else None - argparse_kwargs = _determine_argparse_type_and_nargs(core_type if is_list else param_type_annotation, is_list) + argparse_kwargs = _determine_argparse_type_and_nargs( + core_type if is_list else param_type_annotation, is_list + ) - if has_init_default: + if provided_default is not _DEFAULT_SENTINEL: + argparse_kwargs["default"] = provided_default + elif has_init_default: argparse_kwargs["default"] = init_default_value elif is_overall_optional: # Parameter is Optional (e.g. Optional[int]) and no explicit default in __init__ argparse_kwargs["default"] = None # So, if not provided on CLI, it becomes None. argparse_kwargs["help"] = _build_help_string(cls.__name__, param_name, core_type, is_overall_optional, is_list) - if not has_init_default and not is_overall_optional: # Required if no __init__ default AND not Optional + if ( + provided_default is _DEFAULT_SENTINEL + and not has_init_default + and not is_overall_optional + ): # Required if no defaults are available AND not Optional argparse_kwargs["required"] = True if "default" in argparse_kwargs: # Should not happen if logic is correct del argparse_kwargs["default"] @@ -211,6 +223,7 @@ def _add_arguments_for_class( parser: argparse.ArgumentParser, cls: Type[CliConfigurable], class_arg_configs_maps: Dict[Type[CliConfigurable], Dict[str, str]], # Maps cls to {param_name: dest_name} + provided_defaults: Dict[str, Any] | None = None, ) -> None: """Adds all relevant CLI arguments for a given class by processing its __init__ parameters.""" cls_name_lower = cls.__name__.lower() @@ -240,7 +253,18 @@ def _add_arguments_for_class( # Use the resolved hint if available, otherwise fallback to param_obj.annotation (which might be a string) actual_param_annotation = resolved_hints.get(param_name, param_obj.annotation) - _add_argument_for_parameter(parser, cls, param_name, param_obj, dest_name, actual_param_annotation) + default_override = None + if provided_defaults and param_name in provided_defaults: + default_override = provided_defaults[param_name] + _add_argument_for_parameter( + parser, + cls, + param_name, + param_obj, + dest_name, + actual_param_annotation, + default_override if provided_defaults and param_name in provided_defaults else _DEFAULT_SENTINEL, + ) def _create_argument_parser() -> argparse.ArgumentParser: @@ -294,18 +318,42 @@ def _instantiate_classes( @overload -def lightning_cli(cls1: Type[_C1]) -> _C1: ... +def lightning_cli(cls1: Type[_C1], *, defaults: Dict[Type[CliConfigurable], Dict[str, Any]] | None = None) -> _C1: ... @overload -def lightning_cli(cls1: Type[_C1], cls2: Type[_C2]) -> Tuple[_C1, _C2]: ... +def lightning_cli( + cls1: Type[_C1], + cls2: Type[_C2], + *, + defaults: Dict[Type[CliConfigurable], Dict[str, Any]] | None = None, +) -> Tuple[_C1, _C2]: ... @overload -def lightning_cli(cls1: Type[_C1], cls2: Type[_C2], cls3: Type[_C3]) -> Tuple[_C1, _C2, _C3]: ... +def lightning_cli( + cls1: Type[_C1], + cls2: Type[_C2], + cls3: Type[_C3], + *, + defaults: Dict[Type[CliConfigurable], Dict[str, Any]] | None = None, +) -> Tuple[_C1, _C2, _C3]: ... @overload -def lightning_cli(cls1: Type[_C1], cls2: Type[_C2], cls3: Type[_C3], cls4: Type[_C4]) -> Tuple[_C1, _C2, _C3, _C4]: ... +def lightning_cli( + cls1: Type[_C1], + cls2: Type[_C2], + cls3: Type[_C3], + cls4: Type[_C4], + *, + defaults: Dict[Type[CliConfigurable], Dict[str, Any]] | None = None, +) -> Tuple[_C1, _C2, _C3, _C4]: ... @overload # Fallback for more than 4 or a dynamic number of classes -def lightning_cli(*classes: Type[CliConfigurable]) -> Tuple[CliConfigurable, ...]: ... +def lightning_cli( + *classes: Type[CliConfigurable], + defaults: Dict[Type[CliConfigurable], Dict[str, Any]] | None = None, +) -> Tuple[CliConfigurable, ...]: ... -def lightning_cli(*classes: Type[CliConfigurable]) -> CliConfigurable | Tuple[CliConfigurable, ...]: +def lightning_cli( + *classes: Type[CliConfigurable], + defaults: Dict[Type[CliConfigurable], Dict[str, Any]] | None = None, +) -> CliConfigurable | Tuple[CliConfigurable, ...]: """ Parses command-line arguments to configure and instantiate provided CliConfigurable classes. @@ -325,7 +373,8 @@ def lightning_cli(*classes: Type[CliConfigurable]) -> CliConfigurable | Tuple[Cl class_arg_configs_maps: Dict[Type[CliConfigurable], Dict[str, str]] = {} for cls in classes: - _add_arguments_for_class(parser, cls, class_arg_configs_maps) + defaults_for_cls = defaults.get(cls) if defaults else None + _add_arguments_for_class(parser, cls, class_arg_configs_maps, defaults_for_cls) parsed_args = parser.parse_args() # Uses sys.argv[1:] by default diff --git a/tests/test_config.py b/tests/test_config.py index 0472e644..d6a2ce77 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -460,11 +460,13 @@ def __init__(self, param): # --- Integration Tests for lightning_cli --- -def run_lightning_cli(classes_to_configure, cli_args_list): +def run_lightning_cli(classes_to_configure, cli_args_list, defaults=None): """Helper to run lightning_cli with mocked sys.argv.""" + if defaults is None: + defaults = {} # Prepend a dummy program name to cli_args_list for sys.argv with mock.patch.object(sys, "argv", ["test_program.py"] + cli_args_list): - result = config.lightning_cli(*classes_to_configure) + result = config.lightning_cli(*classes_to_configure, defaults=defaults) if not isinstance(result, tuple): return (result,) return result @@ -595,3 +597,19 @@ def test_lightning_cli_optional_no_default_behavior(): # Provided with a value (cfg3,) = run_lightning_cli([OptionalNoDefaultConfig], ["--optionalnodefaultconfig.opt-val", "ActualValue"]) assert cfg3.opt_val == "ActualValue" + + +def test_lightning_cli_programmatic_defaults_override_required(): + """Tests that defaults passed to lightning_cli satisfy required args and can be overridden.""" + defaults = {SimpleConfig: {"name": "Provided"}} + + # No CLI args, should use provided default for required 'name' + (cfg1,) = run_lightning_cli([SimpleConfig], [], defaults=defaults) + assert cfg1.name == "Provided" + assert cfg1.value == 10 + + # CLI arg should override provided default + (cfg2,) = run_lightning_cli( + [SimpleConfig], ["--simpleconfig.name", "FromCLI"], defaults=defaults + ) + assert cfg2.name == "FromCLI"