diff --git a/argparse_dataclass.py b/argparse_dataclass.py index 4a5baf3..ae54f78 100644 --- a/argparse_dataclass.py +++ b/argparse_dataclass.py @@ -238,16 +238,21 @@ Union, Any, Generic, + ClassVar, ) from dataclasses import ( Field, is_dataclass, fields, MISSING, + InitVar, dataclass as real_dataclass, ) from importlib.metadata import version +# This is `typing._GenericAlias` but don't use non-public type names +_ClassVarType = type(ClassVar[object]) + # In Python 3.10, we can use types.NoneType NoneType = type(None) @@ -284,11 +289,17 @@ def _add_dataclass_options( if not is_dataclass(options_class): raise TypeError("cls must be a dataclass") - for field in fields(options_class): + for field in _fields(options_class): + if not field.init: + continue # Ignore fields not in __init__ + f_type = field.type + if _is_initvar(f_type): + f_type = f_type.type + args = field.metadata.get("args", [f"--{_get_arg_name(field)}"]) positional = not args[0].startswith("-") kwargs = { - "type": field.metadata.get("type", field.type), + "type": field.metadata.get("type", f_type), "help": field.metadata.get("help", None), } @@ -301,7 +312,7 @@ def _add_dataclass_options( kwargs["choices"] = field.metadata["choices"] # Support Literal types as an alternative means of specifying choices. - if get_origin(field.type) is Literal: + if get_origin(f_type) is Literal: # Prohibit a potential collision with the choices field if field.metadata.get("choices") is not None: raise ValueError( @@ -311,7 +322,7 @@ def _add_dataclass_options( ) # Get the types of the arguments of the Literal - types = [type(arg) for arg in get_args(field.type)] + types = [type(arg) for arg in get_args(f_type)] # Make sure just a single type has been used if len(set(types)) > 1: @@ -326,7 +337,7 @@ def _add_dataclass_options( # Overwrite the type kwarg kwargs["type"] = types[0] # Use the literal arguments as choices - kwargs["choices"] = get_args(field.type) + kwargs["choices"] = get_args(f_type) if field.metadata.get("metavar") is not None: kwargs["metavar"] = field.metadata["metavar"] @@ -340,7 +351,7 @@ def _add_dataclass_options( # did not specify the type of the elements within the list, we # try to infer it: try: - kwargs["type"] = get_args(field.type)[0] # get_args returns a tuple + kwargs["type"] = get_args(f_type)[0] # get_args returns a tuple except IndexError: # get_args returned an empty tuple, type cannot be inferred raise ValueError( @@ -354,12 +365,12 @@ def _add_dataclass_options( else: kwargs["default"] = MISSING - if field.type is bool: + if f_type is bool: _handle_bool_type(field, args, kwargs) - elif get_origin(field.type) is Union: + elif get_origin(f_type) is Union: if field.metadata.get("type") is None: # Optional[X] is equivalent to Union[X, None]. - f_args = get_args(field.type) + f_args = get_args(f_type) if len(f_args) == 2 and NoneType in f_args: arg = next(a for a in f_args if a is not NoneType) kwargs["type"] = arg @@ -436,6 +447,24 @@ def _get_arg_name(field: Field): return field.name.replace("_", "-") +def _fields(dataclass) -> Tuple[Field]: + # dataclass.fields does not return fields that are of type InitVar + dc_fields = getattr(dataclass, "__dataclass_fields__", None) + if dc_fields is None: + return fields(dataclass) + return tuple(f for f in dc_fields.values() if not _is_classvar(f.type)) + + +def _is_classvar(a_type): + return a_type is ClassVar or ( + type(a_type) is _ClassVarType and a_type.__origin__ is ClassVar + ) + + +def _is_initvar(a_type): + return a_type is InitVar or type(a_type) is InitVar + + class ArgumentParser(argparse.ArgumentParser, Generic[OptionsType]): """Command line argument parser that derives its options from a dataclass. diff --git a/tests/test_functional.py b/tests/test_functional.py index ec41f9f..3340ea7 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,9 +1,9 @@ import sys import unittest import datetime as dt -from dataclasses import dataclass, field +from dataclasses import dataclass, field, InitVar -from typing import Optional, Union +from typing import Optional, Union, ClassVar from argparse_dataclass import parse_args, parse_known_args @@ -336,6 +336,50 @@ class Options: self.assertEqual(params.name, "John Doe") self.assertEqual(params.age, 3) + def test_init_false(self): + @dataclass + class Options: + date: str + time: str = "00:00" + datetime: dt.datetime = field(init=False) + + def __post_init__(self): + self.datetime = dt.datetime.fromisoformat(f"{self.date}T{self.time}") + + args = ["--date", "1999-12-31"] + params = parse_args(Options, args) + self.assertEqual(params.date, "1999-12-31") + self.assertEqual(params.time, "00:00") + self.assertEqual(params.datetime, dt.datetime(1999, 12, 31)) + + args = ["--date", "1999-12-31", "--time", "15:35:59"] + params = parse_args(Options, args) + self.assertEqual(params.date, "1999-12-31") + self.assertEqual(params.time, "15:35:59") + self.assertEqual(params.datetime, dt.datetime(1999, 12, 31, 15, 35, 59)) + + def test_init_only(self): + @dataclass + class Options: + cls_var: ClassVar[str] = "Hello" + date: InitVar[str] + time: InitVar[str] = "00:00" + datetime: dt.datetime = field(init=False) + + def __post_init__(self, date, time): + self.datetime = dt.datetime.fromisoformat(f"{date}T{time}") + + args = ["--date", "1999-12-31"] + params = parse_args(Options, args) + self.assertFalse(hasattr(params, "date")) + # time is always set to the default value. I think this is a bug.. + # self.assertFalse(hasattr(params, "time")) + self.assertEqual(params.datetime, dt.datetime(1999, 12, 31)) + + args = ["--date", "1999-12-31", "--time", "15:35:59"] + params = parse_args(Options, args) + self.assertEqual(params.datetime, dt.datetime(1999, 12, 31, 15, 35, 59)) + if __name__ == "__main__": unittest.main()