Skip to content

Commit d0b8030

Browse files
committed
Adding support for fields that are only included in __init__
1 parent 9bd1b7c commit d0b8030

File tree

2 files changed

+59
-11
lines changed

2 files changed

+59
-11
lines changed

argparse_dataclass.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -239,16 +239,21 @@
239239
Dict,
240240
Any,
241241
Generic,
242+
ClassVar,
242243
)
243244
from dataclasses import (
244245
Field,
245246
is_dataclass,
246247
fields,
247248
MISSING,
249+
InitVar,
248250
dataclass as real_dataclass,
249251
)
250252
from importlib.metadata import version
251253

254+
# This is `typing._GenericAlias` but don't use non-public type names
255+
_ClassVarType = type(ClassVar[object])
256+
252257
if hasattr(argparse, "BooleanOptionalAction"):
253258
# BooleanOptionalAction was added in Python 3.9
254259
BooleanOptionalAction = argparse.BooleanOptionalAction
@@ -333,14 +338,17 @@ def _add_dataclass_options(
333338
if not is_dataclass(options_class):
334339
raise TypeError("cls must be a dataclass")
335340

336-
for field in fields(options_class):
341+
for field in _fields(options_class):
337342
if not field.init:
338343
continue # Ignore fields not in __init__
344+
f_type = field.type
345+
if _is_initvar(f_type):
346+
f_type = f_type.type
339347

340348
args = field.metadata.get("args", [f"--{_get_arg_name(field)}"])
341349
positional = not args[0].startswith("-")
342350
kwargs = {
343-
"type": field.metadata.get("type", field.type),
351+
"type": field.metadata.get("type", f_type),
344352
"help": field.metadata.get("help", None),
345353
}
346354

@@ -353,7 +361,7 @@ def _add_dataclass_options(
353361
kwargs["choices"] = field.metadata["choices"]
354362

355363
# Support Literal types as an alternative means of specifying choices.
356-
if get_origin(field.type) is Literal:
364+
if get_origin(f_type) is Literal:
357365
# Prohibit a potential collision with the choices field
358366
if field.metadata.get("choices") is not None:
359367
raise ValueError(
@@ -363,7 +371,7 @@ def _add_dataclass_options(
363371
)
364372

365373
# Get the types of the arguments of the Literal
366-
types = [type(arg) for arg in get_args(field.type)]
374+
types = [type(arg) for arg in get_args(f_type)]
367375

368376
# Make sure just a single type has been used
369377
if len(set(types)) > 1:
@@ -378,7 +386,7 @@ def _add_dataclass_options(
378386
# Overwrite the type kwarg
379387
kwargs["type"] = types[0]
380388
# Use the literal arguments as choices
381-
kwargs["choices"] = get_args(field.type)
389+
kwargs["choices"] = get_args(f_type)
382390

383391
if field.metadata.get("metavar") is not None:
384392
kwargs["metavar"] = field.metadata["metavar"]
@@ -392,7 +400,7 @@ def _add_dataclass_options(
392400
# did not specify the type of the elements within the list, we
393401
# try to infer it:
394402
try:
395-
kwargs["type"] = get_args(field.type)[0] # get_args returns a tuple
403+
kwargs["type"] = get_args(f_type)[0] # get_args returns a tuple
396404
except IndexError:
397405
# get_args returned an empty tuple, type cannot be inferred
398406
raise ValueError(
@@ -406,12 +414,12 @@ def _add_dataclass_options(
406414
else:
407415
kwargs["default"] = MISSING
408416

409-
if field.type is bool:
417+
if f_type is bool:
410418
_handle_bool_type(field, args, kwargs)
411-
elif get_origin(field.type) is Union:
419+
elif get_origin(f_type) is Union:
412420
if field.metadata.get("type") is None:
413421
# Optional[X] is equivalent to Union[X, None].
414-
f_args = get_args(field.type)
422+
f_args = get_args(f_type)
415423
if len(f_args) == 2 and NoneType in f_args:
416424
arg = next(a for a in f_args if a is not NoneType)
417425
kwargs["type"] = arg
@@ -488,6 +496,24 @@ def _get_arg_name(field: Field):
488496
return field.name.replace("_", "-")
489497

490498

499+
def _fields(dataclass) -> Tuple[Field]:
500+
# dataclass.fields does not return fields that are of type InitVar
501+
dc_fields = getattr(dataclass, "__dataclass_fields__", None)
502+
if dc_fields is None:
503+
return fields(dataclass)
504+
return tuple(f for f in dc_fields.values() if not _is_classvar(f.type))
505+
506+
507+
def _is_classvar(a_type):
508+
return a_type is ClassVar or (
509+
type(a_type) is _ClassVarType and a_type.__origin__ is ClassVar
510+
)
511+
512+
513+
def _is_initvar(a_type):
514+
return a_type is InitVar or type(a_type) is InitVar
515+
516+
491517
class ArgumentParser(argparse.ArgumentParser, Generic[OptionsType]):
492518
"""Command line argument parser that derives its options from a dataclass.
493519

tests/test_functional.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import sys
22
import unittest
33
import datetime as dt
4-
from dataclasses import dataclass, field
4+
from dataclasses import dataclass, field, InitVar
55

6-
from typing import List, Optional, Union
6+
from typing import List, Optional, Union, ClassVar
77

88
from argparse_dataclass import parse_args, parse_known_args
99

@@ -358,6 +358,28 @@ def __post_init__(self):
358358
self.assertEqual(params.time, "15:35:59")
359359
self.assertEqual(params.datetime, dt.datetime(1999, 12, 31, 15, 35, 59))
360360

361+
def test_init_only(self):
362+
@dataclass
363+
class Options:
364+
cls_var: ClassVar[str] = "Hello"
365+
date: InitVar[str]
366+
time: InitVar[str] = "00:00"
367+
datetime: dt.datetime = field(init=False)
368+
369+
def __post_init__(self, date, time):
370+
self.datetime = dt.datetime.fromisoformat(f"{date}T{time}")
371+
372+
args = ["--date", "1999-12-31"]
373+
params = parse_args(Options, args)
374+
self.assertFalse(hasattr(params, "date"))
375+
# time is always set to the default value. I think this is a bug..
376+
# self.assertFalse(hasattr(params, "time"))
377+
self.assertEqual(params.datetime, dt.datetime(1999, 12, 31))
378+
379+
args = ["--date", "1999-12-31", "--time", "15:35:59"]
380+
params = parse_args(Options, args)
381+
self.assertEqual(params.datetime, dt.datetime(1999, 12, 31, 15, 35, 59))
382+
361383

362384
if __name__ == "__main__":
363385
unittest.main()

0 commit comments

Comments
 (0)