Skip to content

Commit 7abb9a2

Browse files
authored
Modify runopt alias api to be less verbose
Differential Revision: D85591292 Pull Request resolved: #1155
1 parent 7f68b89 commit 7abb9a2

File tree

2 files changed

+39
-147
lines changed

2 files changed

+39
-147
lines changed

torchx/specs/api.py

Lines changed: 19 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -936,42 +936,12 @@ class runopt:
936936
Represents the metadata about the specific run option
937937
"""
938938

939-
class AutoAlias(IntEnum):
940-
snake_case = 0x1
941-
SNAKE_CASE = 0x2
942-
camelCase = 0x4
943-
944-
@staticmethod
945-
def convert_to_camel_case(alias: str) -> str:
946-
words = re.split(r"[_\-\s]+|(?<=[a-z])(?=[A-Z])", alias)
947-
words = [w for w in words if w] # Remove empty strings
948-
if not words:
949-
return ""
950-
return words[0].lower() + "".join(w.capitalize() for w in words[1:])
951-
952-
@staticmethod
953-
def convert_to_snake_case(alias: str) -> str:
954-
alias = re.sub(r"[-\s]+", "_", alias)
955-
alias = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", alias)
956-
alias = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", alias)
957-
return alias.lower()
958-
959-
@staticmethod
960-
def convert_to_const_case(alias: str) -> str:
961-
return runopt.AutoAlias.convert_to_snake_case(alias).upper()
962-
963-
class alias(str):
964-
pass
965-
966-
class deprecated(str):
967-
pass
968-
969939
default: CfgVal
970940
opt_type: Type[CfgVal]
971941
is_required: bool
972942
help: str
973-
aliases: set[alias] | None = None
974-
deprecated_aliases: set[deprecated] | None = None
943+
aliases: list[str] | None = None
944+
deprecated_aliases: list[str] | None = None
975945

976946
@property
977947
def is_type_list_of_str(self) -> bool:
@@ -1257,85 +1227,23 @@ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]:
12571227
cfg[key] = val
12581228
return cfg
12591229

1260-
def _generate_aliases(
1261-
self, auto_alias: int, aliases: set[str]
1262-
) -> set[runopt.alias]:
1263-
generated_aliases = set()
1264-
for alias in aliases:
1265-
if auto_alias & runopt.AutoAlias.camelCase:
1266-
generated_aliases.add(runopt.AutoAlias.convert_to_camel_case(alias))
1267-
if auto_alias & runopt.AutoAlias.snake_case:
1268-
generated_aliases.add(runopt.AutoAlias.convert_to_snake_case(alias))
1269-
if auto_alias & runopt.AutoAlias.SNAKE_CASE:
1270-
generated_aliases.add(runopt.AutoAlias.convert_to_const_case(alias))
1271-
return generated_aliases
1272-
1273-
def _get_primary_key_and_aliases(
1274-
self,
1275-
cfg_key: list[str | int] | str,
1276-
) -> tuple[str, set[runopt.alias], set[runopt.deprecated]]:
1277-
"""
1278-
Returns the primary key and aliases for the given cfg_key.
1279-
"""
1280-
if isinstance(cfg_key, str):
1281-
return cfg_key, set(), set()
1282-
1283-
if len(cfg_key) == 0:
1284-
raise ValueError("cfg_key must be a non-empty list")
1285-
1286-
if isinstance(cfg_key[0], runopt.alias) or isinstance(
1287-
cfg_key[0], runopt.deprecated
1288-
):
1289-
warnings.warn(
1290-
"The main name of the run option should be the head of the list.",
1291-
UserWarning,
1292-
stacklevel=2,
1293-
)
1294-
primary_key = None
1295-
auto_alias = 0x0
1296-
aliases = set[runopt.alias]()
1297-
deprecated_aliases = set[runopt.deprecated]()
1298-
for name in cfg_key:
1299-
if isinstance(name, runopt.alias):
1300-
aliases.add(name)
1301-
elif isinstance(name, runopt.deprecated):
1302-
deprecated_aliases.add(name)
1303-
elif isinstance(name, int):
1304-
auto_alias = auto_alias | name
1305-
else:
1306-
if primary_key is not None:
1307-
raise ValueError(
1308-
f" Given more than one primary key: {primary_key}, {name}. Please use runopt.alias type for aliases. "
1309-
)
1310-
primary_key = name
1311-
if primary_key is None or primary_key == "":
1312-
raise ValueError(
1313-
"Missing cfg_key. Please provide one other than the aliases."
1314-
)
1315-
if auto_alias != 0x0:
1316-
aliases_to_generate_for = aliases | {primary_key}
1317-
additional_aliases = self._generate_aliases(
1318-
auto_alias, aliases_to_generate_for
1319-
)
1320-
aliases.update(additional_aliases)
1321-
return primary_key, aliases, deprecated_aliases
1322-
13231230
def add(
13241231
self,
1325-
cfg_key: str | list[str | int],
1232+
cfg_key: str,
13261233
type_: Type[CfgVal],
13271234
help: str,
13281235
default: CfgVal = None,
13291236
required: bool = False,
1237+
aliases: Optional[list[str]] = None,
1238+
deprecated_aliases: Optional[list[str]] = None,
13301239
) -> None:
13311240
"""
13321241
Adds the ``config`` option with the given help string and ``default``
13331242
value (if any). If the ``default`` is not specified then this option
13341243
is a required option.
13351244
"""
1336-
primary_key, aliases, deprecated_aliases = self._get_primary_key_and_aliases(
1337-
cfg_key
1338-
)
1245+
aliases = aliases or []
1246+
deprecated_aliases = deprecated_aliases or []
13391247
if required and default is not None:
13401248
raise ValueError(
13411249
f"Required option: {cfg_key} must not specify default value. Given: {default}"
@@ -1346,12 +1254,20 @@ def add(
13461254
f"Option: {cfg_key}, must be of type: {type_}."
13471255
f" Given: {default} ({type(default).__name__})"
13481256
)
1349-
opt = runopt(default, type_, required, help, aliases, deprecated_aliases)
1257+
1258+
opt = runopt(
1259+
default,
1260+
type_,
1261+
required,
1262+
help,
1263+
list(set(aliases)),
1264+
list(set(deprecated_aliases)),
1265+
)
13501266
for alias in aliases:
1351-
self._alias_to_key[alias] = primary_key
1267+
self._alias_to_key[alias] = cfg_key
13521268
for deprecated_alias in deprecated_aliases:
1353-
self._alias_to_key[deprecated_alias] = primary_key
1354-
self._opts[primary_key] = opt
1269+
self._alias_to_key[deprecated_alias] = cfg_key
1270+
self._opts[cfg_key] = opt
13551271

13561272
def update(self, other: "runopts") -> None:
13571273
self._opts.update(other._opts)

torchx/specs/test/api_test.py

Lines changed: 20 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,8 @@ def test_runopts_add(self) -> None:
605605
def test_runopts_add_with_aliases(self) -> None:
606606
opts = runopts()
607607
opts.add(
608-
["job_priority", runopt.alias("jobPriority")],
608+
"job_priority",
609+
aliases=["jobPriority"],
609610
type_=str,
610611
help="priority for the job",
611612
)
@@ -616,7 +617,8 @@ def test_runopts_add_with_aliases(self) -> None:
616617
def test_runopts_resolve_with_aliases(self) -> None:
617618
opts = runopts()
618619
opts.add(
619-
["job_priority", runopt.alias("jobPriority")],
620+
"job_priority",
621+
aliases=["jobPriority"],
620622
type_=str,
621623
help="priority for the job",
622624
)
@@ -628,71 +630,45 @@ def test_runopts_resolve_with_aliases(self) -> None:
628630
def test_runopts_resolve_with_none_valued_aliases(self) -> None:
629631
opts = runopts()
630632
opts.add(
631-
["job_priority", runopt.alias("jobPriority")],
633+
"job_priority",
634+
aliases=["jobPriority"],
632635
type_=str,
633636
help="priority for the job",
634637
)
635638
opts.add(
636-
["modelTypeName", runopt.alias("model_type_name")],
639+
"model_type_name",
640+
aliases=["modelTypeName"],
637641
type_=Union[str, None],
638642
help="ML Hub Model Type to attribute resource utilization for job",
639643
)
640-
resolved_opts = opts.resolve({"model_type_name": None, "jobPriority": "low"})
641-
self.assertEqual(resolved_opts.get("model_type_name"), None)
644+
resolved_opts = opts.resolve({"modelTypeName": None, "jobPriority": "low"})
645+
self.assertEqual(resolved_opts.get("modelTypeName"), None)
642646
self.assertEqual(resolved_opts.get("jobPriority"), "low")
643-
self.assertEqual(resolved_opts, {"model_type_name": None, "jobPriority": "low"})
647+
self.assertEqual(resolved_opts, {"modelTypeName": None, "jobPriority": "low"})
644648

645649
with self.assertRaises(InvalidRunConfigException):
646-
opts.resolve({"model_type_name": None, "modelTypeName": "low"})
650+
opts.resolve({"modelTypeName": None, "model_type_name": "low"})
647651

648652
def test_runopts_add_with_deprecated_aliases(self) -> None:
649653
opts = runopts()
650-
with warnings.catch_warnings(record=True) as w:
651-
opts.add(
652-
[runopt.deprecated("jobPriority"), "job_priority"],
653-
type_=str,
654-
help="run as user",
655-
)
656-
self.assertEqual(len(w), 1)
657-
self.assertEqual(w[0].category, UserWarning)
658-
self.assertEqual(
659-
str(w[0].message),
660-
"The main name of the run option should be the head of the list.",
661-
)
654+
opts.add(
655+
"job_priority",
656+
deprecated_aliases=["priority"],
657+
type_=str,
658+
help="run as user",
659+
)
662660

663661
opts.resolve({"job_priority": "high"})
664662
with warnings.catch_warnings(record=True) as w:
665663
warnings.simplefilter("always")
666-
opts.resolve({"jobPriority": "high"})
664+
opts.resolve({"priority": "high"})
667665
self.assertEqual(len(w), 1)
668666
self.assertEqual(w[0].category, UserWarning)
669667
self.assertEqual(
670668
str(w[0].message),
671-
"Run option `jobPriority` is deprecated, use `job_priority` instead",
669+
"Run option `priority` is deprecated, use `job_priority` instead",
672670
)
673671

674-
def test_runopt_auto_aliases(self) -> None:
675-
opts = runopts()
676-
opts.add(
677-
["job_priority", runopt.AutoAlias.camelCase],
678-
type_=str,
679-
help="run as user",
680-
)
681-
opts.add(
682-
[
683-
"model_type_name",
684-
runopt.AutoAlias.camelCase | runopt.AutoAlias.SNAKE_CASE,
685-
],
686-
type_=str,
687-
help="run as user",
688-
)
689-
self.assertEqual(2, len(opts._opts))
690-
self.assertIsNotNone(opts.get("job_priority"))
691-
self.assertIsNotNone(opts.get("jobPriority"))
692-
self.assertIsNotNone(opts.get("model_type_name"))
693-
self.assertIsNotNone(opts.get("modelTypeName"))
694-
self.assertIsNotNone(opts.get("MODEL_TYPE_NAME"))
695-
696672
def get_runopts(self) -> runopts:
697673
opts = runopts()
698674
opts.add("run_as", type_=str, help="run as user", required=True)

0 commit comments

Comments
 (0)