Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 38 additions & 9 deletions argparse_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,16 +238,21 @@
Union,
Any,
Generic,
ClassVar,
)
from dataclasses import (
Field,
is_dataclass,
fields,
MISSING,
InitVar,
dataclass as real_dataclass,
)
from importlib.metadata import version

# This is `typing._GenericAlias` but don't use non-public type names
_ClassVarType = type(ClassVar[object])

# In Python 3.10, we can use types.NoneType
NoneType = type(None)

Expand Down Expand Up @@ -284,11 +289,17 @@ def _add_dataclass_options(
if not is_dataclass(options_class):
raise TypeError("cls must be a dataclass")

for field in fields(options_class):
for field in _fields(options_class):
if not field.init:
continue # Ignore fields not in __init__
f_type = field.type
if _is_initvar(f_type):
f_type = f_type.type

args = field.metadata.get("args", [f"--{_get_arg_name(field)}"])
positional = not args[0].startswith("-")
kwargs = {
"type": field.metadata.get("type", field.type),
"type": field.metadata.get("type", f_type),
"help": field.metadata.get("help", None),
}

Expand All @@ -301,7 +312,7 @@ def _add_dataclass_options(
kwargs["choices"] = field.metadata["choices"]

# Support Literal types as an alternative means of specifying choices.
if get_origin(field.type) is Literal:
if get_origin(f_type) is Literal:
# Prohibit a potential collision with the choices field
if field.metadata.get("choices") is not None:
raise ValueError(
Expand All @@ -311,7 +322,7 @@ def _add_dataclass_options(
)

# Get the types of the arguments of the Literal
types = [type(arg) for arg in get_args(field.type)]
types = [type(arg) for arg in get_args(f_type)]

# Make sure just a single type has been used
if len(set(types)) > 1:
Expand All @@ -326,7 +337,7 @@ def _add_dataclass_options(
# Overwrite the type kwarg
kwargs["type"] = types[0]
# Use the literal arguments as choices
kwargs["choices"] = get_args(field.type)
kwargs["choices"] = get_args(f_type)

if field.metadata.get("metavar") is not None:
kwargs["metavar"] = field.metadata["metavar"]
Expand All @@ -340,7 +351,7 @@ def _add_dataclass_options(
# did not specify the type of the elements within the list, we
# try to infer it:
try:
kwargs["type"] = get_args(field.type)[0] # get_args returns a tuple
kwargs["type"] = get_args(f_type)[0] # get_args returns a tuple
except IndexError:
# get_args returned an empty tuple, type cannot be inferred
raise ValueError(
Expand All @@ -354,12 +365,12 @@ def _add_dataclass_options(
else:
kwargs["default"] = MISSING

if field.type is bool:
if f_type is bool:
_handle_bool_type(field, args, kwargs)
elif get_origin(field.type) is Union:
elif get_origin(f_type) is Union:
if field.metadata.get("type") is None:
# Optional[X] is equivalent to Union[X, None].
f_args = get_args(field.type)
f_args = get_args(f_type)
if len(f_args) == 2 and NoneType in f_args:
arg = next(a for a in f_args if a is not NoneType)
kwargs["type"] = arg
Expand Down Expand Up @@ -436,6 +447,24 @@ def _get_arg_name(field: Field):
return field.name.replace("_", "-")


def _fields(dataclass) -> Tuple[Field]:
# dataclass.fields does not return fields that are of type InitVar
dc_fields = getattr(dataclass, "__dataclass_fields__", None)
if dc_fields is None:
return fields(dataclass)
Comment on lines +453 to +454
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using __dataclass_fields__ seemed to be the only way to get the InitVar fields.
I was not sure what to do if it was not found. I decided to fall back to dataclasses.fields because that currently raise an exception if that attribute is not found.

return tuple(f for f in dc_fields.values() if not _is_classvar(f.type))


def _is_classvar(a_type):
return a_type is ClassVar or (
type(a_type) is _ClassVarType and a_type.__origin__ is ClassVar
)


def _is_initvar(a_type):
return a_type is InitVar or type(a_type) is InitVar
Comment on lines +458 to +465
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are copied and modified from the dataclasses module.



class ArgumentParser(argparse.ArgumentParser, Generic[OptionsType]):
"""Command line argument parser that derives its options from a dataclass.

Expand Down
48 changes: 46 additions & 2 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import sys
import unittest
import datetime as dt
from dataclasses import dataclass, field
from dataclasses import dataclass, field, InitVar

from typing import Optional, Union
from typing import Optional, Union, ClassVar

from argparse_dataclass import parse_args, parse_known_args

Expand Down Expand Up @@ -336,6 +336,50 @@ class Options:
self.assertEqual(params.name, "John Doe")
self.assertEqual(params.age, 3)

def test_init_false(self):
@dataclass
class Options:
date: str
time: str = "00:00"
datetime: dt.datetime = field(init=False)

def __post_init__(self):
self.datetime = dt.datetime.fromisoformat(f"{self.date}T{self.time}")

args = ["--date", "1999-12-31"]
params = parse_args(Options, args)
self.assertEqual(params.date, "1999-12-31")
self.assertEqual(params.time, "00:00")
self.assertEqual(params.datetime, dt.datetime(1999, 12, 31))

args = ["--date", "1999-12-31", "--time", "15:35:59"]
params = parse_args(Options, args)
self.assertEqual(params.date, "1999-12-31")
self.assertEqual(params.time, "15:35:59")
self.assertEqual(params.datetime, dt.datetime(1999, 12, 31, 15, 35, 59))

def test_init_only(self):
@dataclass
class Options:
cls_var: ClassVar[str] = "Hello"
date: InitVar[str]
time: InitVar[str] = "00:00"
datetime: dt.datetime = field(init=False)

def __post_init__(self, date, time):
self.datetime = dt.datetime.fromisoformat(f"{date}T{time}")

args = ["--date", "1999-12-31"]
params = parse_args(Options, args)
self.assertFalse(hasattr(params, "date"))
# time is always set to the default value. I think this is a bug..
# self.assertFalse(hasattr(params, "time"))
self.assertEqual(params.datetime, dt.datetime(1999, 12, 31))

args = ["--date", "1999-12-31", "--time", "15:35:59"]
params = parse_args(Options, args)
self.assertEqual(params.datetime, dt.datetime(1999, 12, 31, 15, 35, 59))


if __name__ == "__main__":
unittest.main()
Loading