Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 101 additions & 6 deletions torchx/specs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pathlib
import re
import typing
import warnings
from dataclasses import asdict, dataclass, field
from datetime import datetime
from enum import Enum
Expand Down Expand Up @@ -891,10 +892,18 @@ class runopt:
Represents the metadata about the specific run option
"""

class alias(str):
pass

class deprecated(str):
pass

default: CfgVal
opt_type: Type[CfgVal]
is_required: bool
help: str
aliases: list[alias] | None = None
deprecated_aliases: list[deprecated] | None = None

@property
def is_type_list_of_str(self) -> bool:
Expand Down Expand Up @@ -986,6 +995,7 @@ class runopts:

def __init__(self) -> None:
self._opts: Dict[str, runopt] = {}
self._alias_to_key: dict[str, str] = {}

def __iter__(self) -> Iterator[Tuple[str, runopt]]:
return self._opts.items().__iter__()
Expand Down Expand Up @@ -1013,9 +1023,16 @@ def is_type(obj: CfgVal, tp: Type[CfgVal]) -> bool:

def get(self, name: str) -> Optional[runopt]:
"""
Returns option if any was registered, or None otherwise
Returns option if any was registered, or None otherwise.
First searches for the option by ``name``, then falls-back to matching ``name`` with any
registered aliases.

"""
return self._opts.get(name, None)
if name in self._opts:
return self._opts[name]
if name in self._alias_to_key:
return self._opts[self._alias_to_key[name]]
return None

def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:
"""
Expand All @@ -1030,6 +1047,36 @@ def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:

for cfg_key, runopt in self._opts.items():
val = resolved_cfg.get(cfg_key)
resolved_name = None
aliases = runopt.aliases or []
deprecated_aliases = runopt.deprecated_aliases or []
if val is None:
for alias in aliases:
val = resolved_cfg.get(alias)
if alias in cfg or val is not None:
resolved_name = alias
break
for alias in deprecated_aliases:
val = resolved_cfg.get(alias)
if val is not None:
resolved_name = alias
use_instead = self._alias_to_key.get(alias)
warnings.warn(
f"Run option `{alias}` is deprecated, use `{use_instead}` instead",
UserWarning,
stacklevel=2,
)
break
else:
resolved_name = cfg_key
for alias in aliases:
duplicate_val = resolved_cfg.get(alias)
if alias in cfg or duplicate_val is not None:
raise InvalidRunConfigException(
f"Duplicate opt name. runopt: `{resolved_name}``, is an alias of runopt: `{alias}`",
resolved_name,
cfg,
)

# check required opt
if runopt.is_required and val is None:
Expand All @@ -1049,7 +1096,7 @@ def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:
)

# not required and not set, set to default
if val is None:
if val is None and resolved_name is None:
resolved_cfg[cfg_key] = runopt.default
return resolved_cfg

Expand Down Expand Up @@ -1142,9 +1189,50 @@ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]:
cfg[key] = val
return cfg

def _get_primary_key_and_aliases(
self,
cfg_key: list[str] | str,
) -> tuple[str, list[runopt.alias], list[runopt.deprecated]]:
"""
Returns the primary key and aliases for the given cfg_key.
"""
if isinstance(cfg_key, str):
return cfg_key, [], []

if len(cfg_key) == 0:
raise ValueError("cfg_key must be a non-empty list")

if isinstance(cfg_key[0], runopt.alias) or isinstance(
cfg_key[0], runopt.deprecated
):
warnings.warn(
"The main name of the run option should be the head of the list.",
UserWarning,
stacklevel=2,
)
primary_key = None
aliases = list[runopt.alias]()
deprecated_aliases = list[runopt.deprecated]()
for name in cfg_key:
if isinstance(name, runopt.alias):
aliases.append(name)
elif isinstance(name, runopt.deprecated):
deprecated_aliases.append(name)
else:
if primary_key is not None:
raise ValueError(
f" Given more than one primary key: {primary_key}, {name}. Please use runopt.alias type for aliases. "
)
primary_key = name
if primary_key is None or primary_key == "":
raise ValueError(
"Missing cfg_key. Please provide one other than the aliases."
)
return primary_key, aliases, deprecated_aliases

def add(
self,
cfg_key: str,
cfg_key: str | list[str],
type_: Type[CfgVal],
help: str,
default: CfgVal = None,
Expand All @@ -1155,6 +1243,9 @@ def add(
value (if any). If the ``default`` is not specified then this option
is a required option.
"""
primary_key, aliases, deprecated_aliases = self._get_primary_key_and_aliases(
cfg_key
)
if required and default is not None:
raise ValueError(
f"Required option: {cfg_key} must not specify default value. Given: {default}"
Expand All @@ -1165,8 +1256,12 @@ def add(
f"Option: {cfg_key}, must be of type: {type_}."
f" Given: {default} ({type(default).__name__})"
)

self._opts[cfg_key] = runopt(default, type_, required, help)
opt = runopt(default, type_, required, help, aliases, deprecated_aliases)
for alias in aliases:
self._alias_to_key[alias] = primary_key
for deprecated_alias in deprecated_aliases:
self._alias_to_key[deprecated_alias] = primary_key
self._opts[primary_key] = opt

def update(self, other: "runopts") -> None:
self._opts.update(other._opts)
Expand Down
70 changes: 70 additions & 0 deletions torchx/specs/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import tempfile
import time
import unittest
import warnings
from dataclasses import asdict
from pathlib import Path
from typing import Dict, List, Mapping, Tuple, Union
Expand Down Expand Up @@ -578,6 +579,75 @@ def test_runopts_add(self) -> None:
# this print is intentional (demonstrates the intended usecase)
print(opts)

def test_runopts_add_with_aliases(self) -> None:
opts = runopts()
opts.add(
["job_priority", runopt.alias("jobPriority")],
type_=str,
help="priority for the job",
)
self.assertEqual(1, len(opts._opts))
self.assertIsNotNone(opts.get("job_priority"))
self.assertIsNotNone(opts.get("jobPriority"))

def test_runopts_resolve_with_aliases(self) -> None:
opts = runopts()
opts.add(
["job_priority", runopt.alias("jobPriority")],
type_=str,
help="priority for the job",
)
opts.resolve({"job_priority": "high"})
opts.resolve({"jobPriority": "low"})
with self.assertRaises(InvalidRunConfigException):
opts.resolve({"job_priority": "high", "jobPriority": "low"})

def test_runopts_resolve_with_none_valued_aliases(self) -> None:
opts = runopts()
opts.add(
["job_priority", runopt.alias("jobPriority")],
type_=str,
help="priority for the job",
)
opts.add(
["modelTypeName", runopt.alias("model_type_name")],
type_=Union[str, None],
help="ML Hub Model Type to attribute resource utilization for job",
)
resolved_opts = opts.resolve({"model_type_name": None, "jobPriority": "low"})
self.assertEqual(resolved_opts.get("model_type_name"), None)
self.assertEqual(resolved_opts.get("jobPriority"), "low")
self.assertEqual(resolved_opts, {"model_type_name": None, "jobPriority": "low"})

with self.assertRaises(InvalidRunConfigException):
opts.resolve({"model_type_name": None, "modelTypeName": "low"})

def test_runopts_add_with_deprecated_aliases(self) -> None:
opts = runopts()
with warnings.catch_warnings(record=True) as w:
opts.add(
[runopt.deprecated("jobPriority"), "job_priority"],
type_=str,
help="run as user",
)
self.assertEqual(len(w), 1)
self.assertEqual(w[0].category, UserWarning)
self.assertEqual(
str(w[0].message),
"The main name of the run option should be the head of the list.",
)

opts.resolve({"job_priority": "high"})
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
opts.resolve({"jobPriority": "high"})
self.assertEqual(len(w), 1)
self.assertEqual(w[0].category, UserWarning)
self.assertEqual(
str(w[0].message),
"Run option `jobPriority` is deprecated, use `job_priority` instead",
)

def get_runopts(self) -> runopts:
opts = runopts()
opts.add("run_as", type_=str, help="run as user", required=True)
Expand Down
Loading