Skip to content

Commit 16570e9

Browse files
lllilithyangchenyangliao
andauthored
allow setting dict in command.identity , sweep.search_space in pipleine (Azure#27696)
* add identity setter for command, add search_space setter for sweep * resolve comments * add early_termination setter for sweep, resolve comments * fix pylint error * fix pylint error * resolve comments Co-authored-by: chenyangliao <[email protected]>
1 parent ea94f14 commit 16570e9

File tree

4 files changed

+124
-1
lines changed

4 files changed

+124
-1
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/command.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from azure.ai.ml._schema.core.fields import NestedField, UnionField
2020
from azure.ai.ml._schema.job.command_job import CommandJobSchema
2121
from azure.ai.ml._schema.job.services import JobServiceSchema
22+
from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
2223
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, LOCAL_COMPUTE_PROPERTY, LOCAL_COMPUTE_TARGET
2324
from azure.ai.ml.constants._component import ComponentSource, NodeType
2425
from azure.ai.ml.entities._assets import Environment
@@ -240,6 +241,32 @@ def resources(self, value: Union[Dict, JobResourceConfiguration]):
240241
value = JobResourceConfiguration(**value)
241242
self._resources = value
242243

244+
@property
245+
def identity(
246+
self,
247+
) -> Optional[Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]]:
248+
"""
249+
Configuration of the hyperparameter identity.
250+
"""
251+
return self._identity
252+
253+
@identity.setter
254+
def identity(self, value: Union[
255+
Dict[str, str],
256+
ManagedIdentityConfiguration,
257+
AmlTokenConfiguration,
258+
UserIdentityConfiguration, None]):
259+
if isinstance(value, dict):
260+
identity_schema = UnionField(
261+
[
262+
NestedField(ManagedIdentitySchema, unknown=INCLUDE),
263+
NestedField(AMLTokenIdentitySchema, unknown=INCLUDE),
264+
NestedField(UserIdentitySchema, unknown=INCLUDE),
265+
]
266+
)
267+
value = identity_schema._deserialize(value=value, attr=None, data=None)
268+
self._identity = value
269+
243270
@property
244271
def services(self) -> Dict:
245272
return self._services

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/sweep.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
1313
from azure.ai.ml.constants._component import NodeType
14+
from azure.ai.ml.constants._job.sweep import SearchSpace
1415
from azure.ai.ml.entities._component.command_component import CommandComponent
1516
from azure.ai.ml.entities._inputs_outputs import Input, Output
1617
from azure.ai.ml.entities._credentials import (
@@ -24,6 +25,7 @@
2425
BanditPolicy,
2526
MedianStoppingPolicy,
2627
TruncationSelectionPolicy,
28+
EarlyTerminationPolicy,
2729
)
2830
from azure.ai.ml.entities._job.sweep.objective import Objective
2931
from azure.ai.ml.entities._job.sweep.parameterized_sweep import ParameterizedSweep
@@ -41,8 +43,14 @@
4143
SweepDistribution,
4244
Uniform,
4345
)
44-
from azure.ai.ml.exceptions import ErrorTarget, UserErrorException, ValidationErrorType, ValidationException
46+
from azure.ai.ml.exceptions import (
47+
ErrorTarget,
48+
UserErrorException,
49+
ValidationErrorType,
50+
ValidationException,
51+
)
4552
from azure.ai.ml.sweep import SweepJob
53+
from azure.ai.ml._schema._sweep.sweep_fields_provider import EarlyTerminationField
4654

4755
from ..._schema import PathAwareSchema
4856
from ..._schema._utils.data_binding_expression import support_data_binding_expression_for_fields
@@ -149,6 +157,32 @@ def trial(self):
149157
"""Id or instance of the command component/job to be run for the step."""
150158
return self._component
151159

160+
@property
161+
def search_space(self):
162+
"""Dictionary of the hyperparameter search space.
163+
The key is the name of the hyperparameter and the value is the parameter expression.
164+
"""
165+
return self._search_space
166+
167+
@search_space.setter
168+
def search_space(self, values: Dict[str, Dict[str, Union[str, int, float, dict]]]):
169+
search_space = {}
170+
for name, value in values.items():
171+
# If value is a SearchSpace object, directly pass it to job.search_space[name]
172+
search_space[name] = self._value_type_to_class(value) if isinstance(value, dict) else value
173+
self._search_space = search_space
174+
175+
@classmethod
176+
def _value_type_to_class(cls, value):
177+
value_type = value['type']
178+
search_space_dict = {
179+
SearchSpace.CHOICE: Choice, SearchSpace.RANDINT: Randint, SearchSpace.LOGNORMAL: LogNormal,
180+
SearchSpace.NORMAL: Normal, SearchSpace.LOGUNIFORM: LogUniform, SearchSpace.UNIFORM: Uniform,
181+
SearchSpace.QLOGNORMAL: QLogNormal, SearchSpace.QNORMAL: QNormal, SearchSpace.QLOGUNIFORM: QLogUniform,
182+
SearchSpace.QUNIFORM: QUniform
183+
}
184+
return search_space_dict[value_type](**value)
185+
152186
@classmethod
153187
def _get_supported_inputs_types(cls):
154188
supported_types = super()._get_supported_inputs_types() or ()
@@ -320,3 +354,14 @@ def __setattr__(self, key, value):
320354
self.early_termination.slack_amount = None
321355
if self.early_termination.slack_factor == 0.0:
322356
self.early_termination.slack_factor = None
357+
358+
@property
359+
def early_termination(self) -> Union[str, EarlyTerminationPolicy]:
360+
return self._early_termination
361+
362+
@early_termination.setter
363+
def early_termination(self, value: Union[EarlyTerminationPolicy, Dict[str, Union[str, float, int, bool]]]):
364+
if isinstance(value, dict):
365+
early_termination_schema = EarlyTerminationField()
366+
value = early_termination_schema._deserialize(value=value, attr=None, data=None)
367+
self._early_termination = value

sdk/ml/azure-ai-ml/tests/component/unittests/test_command_component_entity.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,28 @@ def test_sweep_help_function(self):
367367
in std_out.getvalue()
368368
)
369369

370+
def test_sweep_early_termination_setter(self):
371+
yaml_file = "./tests/test_configs/components/helloworld_component.yml"
372+
373+
component_to_sweep: CommandComponent = load_component(source=yaml_file)
374+
cmd_node1: Command = component_to_sweep(
375+
component_in_number=Choice([2, 3, 4, 5]), component_in_path=Input(path="/a/path/on/ds")
376+
)
377+
378+
sweep_job1: Sweep = cmd_node1.sweep(
379+
primary_metric="AUC", # primary_metric,
380+
goal="maximize",
381+
sampling_algorithm="random",
382+
)
383+
sweep_job1.early_termination = {
384+
'type': "bandit", 'evaluation_interval': 100, 'delay_evaluation': 200, 'slack_factor': 40.0
385+
}
386+
from azure.ai.ml.entities._job.sweep.early_termination_policy import BanditPolicy
387+
assert isinstance(sweep_job1.early_termination, BanditPolicy)
388+
assert [sweep_job1.early_termination.evaluation_interval,
389+
sweep_job1.early_termination.delay_evaluation,
390+
sweep_job1.early_termination.slack_factor] == [100, 200, 40.0]
391+
370392
def test_invalid_component_inputs(self) -> None:
371393
yaml_path = "./tests/test_configs/components/invalid/helloworld_component_conflict_input_names.yml"
372394
component = load_component(yaml_path)

sdk/ml/azure-ai-ml/tests/dsl/unittests/test_command_builder.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,3 +970,32 @@ def my_pipeline():
970970
"type": "command",
971971
}
972972
}
973+
974+
def test_set_identity(self, test_command):
975+
from azure.ai.ml.entities._credentials import AmlTokenConfiguration
976+
node1 = test_command()
977+
node2 = node1()
978+
node2.identity = AmlTokenConfiguration()
979+
node3 = node1()
980+
node3.identity = {'type': 'AMLToken'}
981+
assert node2.identity == node3.identity
982+
983+
def test_sweep_set_search_space(self, test_command):
984+
from azure.ai.ml.entities._job.sweep.search_space import Choice
985+
node1 = test_command()
986+
command_node_to_sweep_1 = node1()
987+
sweep_node_1 = command_node_to_sweep_1.sweep(
988+
primary_metric="AUC",
989+
goal="maximize",
990+
sampling_algorithm="random",
991+
)
992+
sweep_node_1.search_space = {'batch_size': {'type': 'choice', 'values': [25, 35]}}
993+
994+
command_node_to_sweep_2 = node1()
995+
sweep_node_2 = command_node_to_sweep_2.sweep(
996+
primary_metric="AUC",
997+
goal="maximize",
998+
sampling_algorithm="random",
999+
)
1000+
sweep_node_2.search_space = {'batch_size': Choice(values=[25, 35])}
1001+
assert sweep_node_1.search_space == sweep_node_2.search_space

0 commit comments

Comments
 (0)