Skip to content

Commit 98dd8ec

Browse files
authored
fix: support data binding expression for resources.xxx (Azure#29559)
1 parent 9da2f42 commit 98dd8ec

File tree

18 files changed

+314
-140
lines changed

18 files changed

+314
-140
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/command.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from marshmallow import INCLUDE, Schema
88

99
from ... import MpiDistribution, PyTorchDistribution, TensorFlowDistribution
10-
from ..._restclient.v2023_02_01_preview.models import JobResourceConfiguration as RestJobResourceConfiguration
1110
from ..._schema import PathAwareSchema
1211
from ..._schema.core.fields import DistributionField
1312
from ...entities import CommandJobLimits, JobResourceConfiguration
@@ -106,12 +105,11 @@ def _from_rest_object_to_init_params(cls, obj):
106105
obj = InternalBaseNode._from_rest_object_to_init_params(obj)
107106

108107
if "resources" in obj and obj["resources"]:
109-
resources = RestJobResourceConfiguration.from_dict(obj["resources"])
110-
obj["resources"] = JobResourceConfiguration._from_rest_object(resources)
108+
obj["resources"] = JobResourceConfiguration._from_rest_object(obj["resources"])
111109

112110
# handle limits
113111
if "limits" in obj and obj["limits"]:
114-
obj["limits"] = CommandJobLimits()._from_rest_object(obj["limits"])
112+
obj["limits"] = CommandJobLimits._from_rest_object(obj["limits"])
115113
return obj
116114

117115

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/_utils/data_binding_expression.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from marshmallow import Schema, fields
77

8-
from azure.ai.ml._schema.core.fields import DataBindingStr, NestedField, UnionField
8+
from azure.ai.ml._schema.core.fields import DataBindingStr, ExperimentalField, NestedField, UnionField
99
from azure.ai.ml._schema.core.schema import PathAwareSchema
1010

1111
DATA_BINDING_SUPPORTED_KEY = "_data_binding_supported"
@@ -31,6 +31,15 @@ def _add_data_binding_to_field(field, attrs_to_skip, schema_stack):
3131
elif isinstance(field, fields.List):
3232
# handle list
3333
field.inner = _add_data_binding_to_field(field.inner, attrs_to_skip, schema_stack=schema_stack)
34+
elif isinstance(field, ExperimentalField):
35+
field = ExperimentalField(
36+
_add_data_binding_to_field(field.experimental_field, attrs_to_skip, schema_stack=schema_stack),
37+
data_key=field.data_key,
38+
attribute=field.attribute,
39+
dump_only=field.dump_only,
40+
required=field.required,
41+
allow_none=field.allow_none,
42+
)
3443
elif isinstance(field, NestedField):
3544
# handle nested field
3645
support_data_binding_expression_for_fields(field.schema, attrs_to_skip, schema_stack=schema_stack)

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,12 @@
1616
from marshmallow import RAISE, fields
1717
from marshmallow.exceptions import ValidationError
1818
from marshmallow.fields import _T, Field, Nested
19-
from marshmallow.utils import (
20-
FieldInstanceResolutionError,
21-
from_iso_datetime,
22-
resolve_field_instance,
23-
)
19+
from marshmallow.utils import FieldInstanceResolutionError, from_iso_datetime, resolve_field_instance
2420

25-
from azure.ai.ml._schema.core.schema import PathAwareSchema
26-
from azure.ai.ml._utils._arm_id_utils import (
27-
AMLVersionedArmId,
28-
is_ARM_id_for_resource,
29-
parse_name_label,
30-
parse_name_version,
31-
)
32-
from azure.ai.ml._utils._experimental import _is_warning_cached
33-
from azure.ai.ml._utils.utils import (
34-
is_data_binding_expression,
35-
is_valid_node_name,
36-
load_file,
37-
load_yaml,
38-
)
39-
from azure.ai.ml.constants._common import (
21+
from ..._utils._arm_id_utils import AMLVersionedArmId, is_ARM_id_for_resource, parse_name_label, parse_name_version
22+
from ..._utils._experimental import _is_warning_cached
23+
from ..._utils.utils import is_data_binding_expression, is_valid_node_name, load_file, load_yaml
24+
from ...constants._common import (
4025
ARM_ID_PREFIX,
4126
AZUREML_RESOURCE_PROVIDER,
4227
BASE_PATH_CONTEXT_KEY,
@@ -47,16 +32,15 @@
4732
FILE_PREFIX,
4833
INTERNAL_REGISTRY_URI_FORMAT,
4934
LOCAL_COMPUTE_TARGET,
50-
SERVERLESS_COMPUTE,
5135
LOCAL_PATH,
5236
REGISTRY_URI_FORMAT,
5337
RESOURCE_ID_FORMAT,
38+
SERVERLESS_COMPUTE,
5439
AzureMLResourceType,
5540
)
56-
from azure.ai.ml.entities._job.pipeline._attr_dict import (
57-
try_get_non_arbitrary_attr_for_potential_attr_dict,
58-
)
59-
from azure.ai.ml.exceptions import ValidationException
41+
from ...entities._job.pipeline._attr_dict import try_get_non_arbitrary_attr_for_potential_attr_dict
42+
from ...exceptions import ValidationException
43+
from ..core.schema import PathAwareSchema
6044

6145
module_logger = logging.getLogger(__name__)
6246

@@ -783,6 +767,10 @@ def __init__(self, experimental_field: fields.Field, **kwargs):
783767
'"experimental_field" must be subclasses or ' "instances of marshmallow.base.FieldABC."
784768
) from error
785769

770+
@property
771+
def experimental_field(self):
772+
return self._experimental_field
773+
786774
# This sets the parent for the schema and also handles nesting.
787775
def _bind_to_schema(self, field_name, schema):
788776
super()._bind_to_schema(field_name, schema)

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/pipeline/component_job.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,46 +8,52 @@
88

99
from marshmallow import INCLUDE, ValidationError, fields, post_dump, post_load, pre_dump, validates
1010

11-
from azure.ai.ml._schema.assets.environment import AnonymousEnvironmentSchema
12-
from azure.ai.ml._schema.component import (
11+
from ..._schema.assets.environment import AnonymousEnvironmentSchema
12+
from ..._schema.component import (
1313
AnonymousCommandComponentSchema,
14+
AnonymousDataTransferCopyComponentSchema,
1415
AnonymousImportComponentSchema,
1516
AnonymousParallelComponentSchema,
1617
AnonymousSparkComponentSchema,
17-
AnonymousDataTransferCopyComponentSchema,
1818
ComponentFileRefField,
19+
DataTransferCopyComponentFileRefField,
1920
ImportComponentFileRefField,
2021
ParallelComponentFileRefField,
2122
SparkComponentFileRefField,
22-
DataTransferCopyComponentFileRefField,
2323
)
24-
from azure.ai.ml._schema.core.fields import ArmVersionedStr, NestedField, RegistryStr, UnionField
25-
from azure.ai.ml._schema.core.schema import PathAwareSchema
26-
from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
27-
from azure.ai.ml._schema.job.input_output_entry import OutputSchema, DatabaseSchema, FileSystemSchema
28-
from azure.ai.ml._schema.job.input_output_fields_provider import InputsField
29-
from azure.ai.ml._schema.pipeline.pipeline_job_io import OutputBindingStr
30-
from azure.ai.ml._schema.spark_resource_configuration import SparkResourceConfigurationSchema
31-
from azure.ai.ml._utils.utils import is_data_binding_expression
32-
from azure.ai.ml.constants._common import AzureMLResourceType
33-
from azure.ai.ml.constants._component import NodeType, DataTransferTaskType
34-
from azure.ai.ml.entities._inputs_outputs import Input
35-
24+
from ..._utils.utils import is_data_binding_expression
25+
from ...constants._common import AzureMLResourceType
26+
from ...constants._component import DataTransferTaskType, NodeType
27+
from ...entities._inputs_outputs import Input
3628
from ...entities._job.pipeline._attr_dict import _AttrDict
3729
from ...exceptions import ValidationException
3830
from .._sweep.parameterized_sweep import ParameterizedSweepSchema
3931
from .._utils.data_binding_expression import support_data_binding_expression_for_fields
40-
from ..core.fields import ComputeField, StringTransformedEnum, TypeSensitiveUnionField
32+
from ..core.fields import (
33+
ArmVersionedStr,
34+
ComputeField,
35+
NestedField,
36+
RegistryStr,
37+
StringTransformedEnum,
38+
TypeSensitiveUnionField,
39+
UnionField,
40+
)
41+
from ..core.schema import PathAwareSchema
4142
from ..job import ParameterizedCommandSchema, ParameterizedParallelSchema, ParameterizedSparkSchema
43+
from ..job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
44+
from ..job.input_output_entry import DatabaseSchema, FileSystemSchema, OutputSchema
45+
from ..job.input_output_fields_provider import InputsField
4246
from ..job.job_limits import CommandJobLimitsSchema
4347
from ..job.parameterized_spark import SparkEntryClassSchema, SparkEntryFileSchema
4448
from ..job.services import (
4549
JobServiceSchema,
46-
SshJobServiceSchema,
4750
JupyterLabJobServiceSchema,
48-
VsCodeJobServiceSchema,
51+
SshJobServiceSchema,
4952
TensorBoardJobServiceSchema,
53+
VsCodeJobServiceSchema,
5054
)
55+
from ..pipeline.pipeline_job_io import OutputBindingStr
56+
from ..spark_resource_configuration import SparkResourceConfigurationSchema
5157

5258
module_logger = logging.getLogger(__name__)
5359

@@ -76,7 +82,11 @@ def add_user_setting_attr_dict(self, data, original_data, **kwargs): # pylint:
7682
"""Support serializing unknown fields for pipeline node."""
7783
if isinstance(original_data, _AttrDict):
7884
user_setting_attr_dict = original_data._get_attrs()
79-
data.update(user_setting_attr_dict)
85+
# TODO: dump _AttrDict values to serializable data like dict instead of original object
86+
# skip fields that are already serialized
87+
for key, value in user_setting_attr_dict.items():
88+
if key not in data:
89+
data[key] = value
8090
return data
8191

8292
# an alternative would be set schema property to be load_only, but sub-schemas like CommandSchema usually also

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/command.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515

1616
from azure.ai.ml._restclient.v2023_02_01_preview.models import CommandJob as RestCommandJob
1717
from azure.ai.ml._restclient.v2023_02_01_preview.models import JobBase
18-
from azure.ai.ml._restclient.v2023_02_01_preview.models import JobResourceConfiguration as RestJobResourceConfiguration
19-
from azure.ai.ml._restclient.v2023_02_01_preview.models import QueueSettings as RestQueueSettings
2018
from azure.ai.ml._schema.core.fields import NestedField, UnionField
2119
from azure.ai.ml._schema.job.command_job import CommandJobSchema
2220
from azure.ai.ml._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
@@ -44,8 +42,8 @@
4442
from azure.ai.ml.entities._job.job_limits import CommandJobLimits
4543
from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration
4644
from azure.ai.ml.entities._job.job_service import (
47-
JobServiceBase,
4845
JobService,
46+
JobServiceBase,
4947
JupyterLabJobService,
5048
SshJobService,
5149
TensorBoardJobService,
@@ -591,8 +589,7 @@ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
591589
obj = BaseNode._from_rest_object_to_init_params(obj)
592590

593591
if "resources" in obj and obj["resources"]:
594-
resources = RestJobResourceConfiguration.from_dict(obj["resources"])
595-
obj["resources"] = JobResourceConfiguration._from_rest_object(resources)
592+
obj["resources"] = JobResourceConfiguration._from_rest_object(obj["resources"])
596593

597594
# services, sweep won't have services
598595
if "services" in obj and obj["services"]:
@@ -614,8 +611,7 @@ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
614611
obj["identity"] = _BaseJobIdentityConfiguration._load(obj["identity"])
615612

616613
if "queue_settings" in obj and obj["queue_settings"]:
617-
queue_settings = RestQueueSettings.from_dict(obj["queue_settings"])
618-
obj["queue_settings"] = QueueSettings._from_rest_object(queue_settings)
614+
obj["queue_settings"] = QueueSettings._from_rest_object(obj["queue_settings"])
619615

620616
return obj
621617

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/parallel.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@
1212
from typing import Dict, List, Optional, Union
1313

1414
from marshmallow import Schema
15-
from azure.ai.ml.constants._common import ARM_ID_PREFIX
16-
from azure.ai.ml.constants._component import NodeType
17-
from azure.ai.ml.entities._component.component import Component
18-
from azure.ai.ml.entities._component.parallel_component import ParallelComponent
19-
from azure.ai.ml.entities._inputs_outputs import Input, Output
20-
from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration
21-
from azure.ai.ml.entities._job.parallel.parallel_job import ParallelJob
22-
from azure.ai.ml.entities._job.parallel.parallel_task import ParallelTask
23-
from azure.ai.ml.entities._job.parallel.retry_settings import RetrySettings
2415

2516
from ..._schema import PathAwareSchema
17+
from ...constants._common import ARM_ID_PREFIX
18+
from ...constants._component import NodeType
19+
from .._component.component import Component
20+
from .._component.parallel_component import ParallelComponent
21+
from .._inputs_outputs import Input, Output
22+
from .._job.job_resource_configuration import JobResourceConfiguration
23+
from .._job.parallel.parallel_job import ParallelJob
24+
from .._job.parallel.parallel_task import ParallelTask
25+
from .._job.parallel.retry_settings import RetrySettings
2626
from .._job.pipeline._io import NodeOutput
2727
from .._util import convert_ordered_dict_to_dict, get_rest_dict_for_node_attrs, validate_attribute_type
2828
from .base_node import BaseNode
@@ -355,7 +355,7 @@ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
355355
obj["task"].environment = task_env[len(ARM_ID_PREFIX) :]
356356

357357
if "resources" in obj and obj["resources"]:
358-
obj["resources"] = JobResourceConfiguration._from_dict(obj["resources"])
358+
obj["resources"] = JobResourceConfiguration._from_rest_object(obj["resources"])
359359

360360
if "partition_keys" in obj and obj["partition_keys"]:
361361
obj["partition_keys"] = json.dumps(obj["partition_keys"])

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/spark.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,8 @@
1212

1313
from marshmallow import INCLUDE, Schema
1414

15-
from ..._restclient.v2023_02_01_preview.models import IdentityConfiguration
1615
from ..._restclient.v2023_02_01_preview.models import JobBase as JobBaseData
1716
from ..._restclient.v2023_02_01_preview.models import SparkJob as RestSparkJob
18-
from ..._restclient.v2023_02_01_preview.models import SparkJobEntry as RestSparkJobEntry
19-
from ..._restclient.v2023_02_01_preview.models import SparkResourceConfiguration as RestSparkResourceConfiguration
2017
from ..._schema import NestedField, PathAwareSchema, UnionField
2118
from ..._schema.job.identity import AMLTokenIdentitySchema, ManagedIdentitySchema, UserIdentitySchema
2219
from ..._schema.job.parameterized_spark import CONF_KEY_MAP, SparkConfSchema
@@ -299,16 +296,13 @@ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
299296
obj = super()._from_rest_object_to_init_params(obj)
300297

301298
if "resources" in obj and obj["resources"]:
302-
resources = RestSparkResourceConfiguration.from_dict(obj["resources"])
303-
obj["resources"] = SparkResourceConfiguration._from_rest_object(resources)
299+
obj["resources"] = SparkResourceConfiguration._from_rest_object(obj["resources"])
304300

305301
if "identity" in obj and obj["identity"]:
306-
identity = IdentityConfiguration.from_dict(obj["identity"])
307-
obj["identity"] = _BaseJobIdentityConfiguration._from_rest_object(identity)
302+
obj["identity"] = _BaseJobIdentityConfiguration._from_rest_object(obj["identity"])
308303

309304
if "entry" in obj and obj["entry"]:
310-
entry = RestSparkJobEntry.from_dict(obj["entry"])
311-
obj["entry"] = SparkJobEntry._from_rest_object(entry)
305+
obj["entry"] = SparkJobEntry._from_rest_object(obj["entry"])
312306
if "conf" in obj and obj["conf"]:
313307
identify_schema = UnionField(
314308
[

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_credentials.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from typing import Dict, List, Optional, Union
99

1010
from azure.ai.ml._azure_environments import _get_active_directory_url_from_metadata
11-
from azure.ai.ml._restclient.v2022_12_01_preview.models import ConnectionAuthType
1211
from azure.ai.ml._restclient.v2022_01_01_preview.models import Identity as RestIdentityConfiguration
1312
from azure.ai.ml._restclient.v2022_01_01_preview.models import ManagedIdentity as RestWorkspaceConnectionManagedIdentity
1413
from azure.ai.ml._restclient.v2022_01_01_preview.models import (
@@ -24,6 +23,8 @@
2423
from azure.ai.ml._restclient.v2022_01_01_preview.models import (
2524
UsernamePassword as RestWorkspaceConnectionUsernamePassword,
2625
)
26+
from azure.ai.ml._restclient.v2022_05_01.models import ManagedServiceIdentity as RestManagedServiceIdentityConfiguration
27+
from azure.ai.ml._restclient.v2022_05_01.models import UserAssignedIdentity as RestUserAssignedIdentityConfiguration
2728
from azure.ai.ml._restclient.v2022_10_01.models import (
2829
AccountKeyDatastoreCredentials as RestAccountKeyDatastoreCredentials,
2930
)
@@ -32,7 +33,6 @@
3233
CertificateDatastoreCredentials as RestCertificateDatastoreCredentials,
3334
)
3435
from azure.ai.ml._restclient.v2022_10_01.models import CertificateDatastoreSecrets, CredentialsType
35-
from azure.ai.ml._restclient.v2022_05_01.models import ManagedServiceIdentity as RestManagedServiceIdentityConfiguration
3636
from azure.ai.ml._restclient.v2022_10_01.models import NoneDatastoreCredentials as RestNoneDatastoreCredentials
3737
from azure.ai.ml._restclient.v2022_10_01.models import SasDatastoreCredentials as RestSasDatastoreCredentials
3838
from azure.ai.ml._restclient.v2022_10_01.models import SasDatastoreSecrets as RestSasDatastoreSecrets
@@ -42,16 +42,16 @@
4242
from azure.ai.ml._restclient.v2022_10_01.models import (
4343
ServicePrincipalDatastoreSecrets as RestServicePrincipalDatastoreSecrets,
4444
)
45-
from azure.ai.ml._restclient.v2022_05_01.models import UserAssignedIdentity as RestUserAssignedIdentityConfiguration
45+
from azure.ai.ml._restclient.v2022_12_01_preview.models import ConnectionAuthType
46+
from azure.ai.ml._restclient.v2022_12_01_preview.models import (
47+
WorkspaceConnectionAccessKey as RestWorkspaceConnectionAccessKey,
48+
)
4649
from azure.ai.ml._restclient.v2023_02_01_preview.models import AmlToken as RestAmlToken
4750
from azure.ai.ml._restclient.v2023_02_01_preview.models import IdentityConfiguration as RestJobIdentityConfiguration
4851
from azure.ai.ml._restclient.v2023_02_01_preview.models import IdentityConfigurationType
4952
from azure.ai.ml._restclient.v2023_02_01_preview.models import ManagedIdentity as RestJobManagedIdentity
5053
from azure.ai.ml._restclient.v2023_02_01_preview.models import ManagedServiceIdentity as RestRegistryManagedIdentity
5154
from azure.ai.ml._restclient.v2023_02_01_preview.models import UserIdentity as RestUserIdentity
52-
from azure.ai.ml._restclient.v2022_12_01_preview.models import (
53-
WorkspaceConnectionAccessKey as RestWorkspaceConnectionAccessKey,
54-
)
5555
from azure.ai.ml._utils.utils import camel_to_snake, snake_to_pascal
5656
from azure.ai.ml.constants._common import CommonYamlFields, IdentityType
5757
from azure.ai.ml.entities._mixins import DictMixin, RestTranslatableMixin, YamlTranslatableMixin
@@ -327,12 +327,18 @@ def __init__(self):
327327

328328
@classmethod
329329
def _from_rest_object(cls, obj: RestJobIdentityConfiguration) -> "Identity":
330+
if obj is None:
331+
return None
330332
mapping = {
331333
IdentityConfigurationType.AML_TOKEN: AmlTokenConfiguration,
332334
IdentityConfigurationType.MANAGED: ManagedIdentityConfiguration,
333335
IdentityConfigurationType.USER_IDENTITY: UserIdentityConfiguration,
334336
}
335337

338+
if isinstance(obj, dict):
339+
# TODO: support data binding expression
340+
obj = RestJobIdentityConfiguration.from_dict(obj)
341+
336342
identity_class = mapping.get(obj.identity_type, None)
337343
if identity_class:
338344
# pylint: disable=protected-access

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/job_limits.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from azure.ai.ml._restclient.v2022_12_01_preview.models import SweepJobLimits as RestSweepJobLimits
1111
from azure.ai.ml._utils.utils import from_iso_duration_format, is_data_binding_expression, to_iso_duration_format
1212
from azure.ai.ml.constants import JobType
13-
from azure.ai.ml.entities._job.pipeline._io import PipelineInput
1413
from azure.ai.ml.entities._mixins import RestTranslatableMixin
1514

1615
module_logger = logging.getLogger(__name__)
@@ -44,7 +43,7 @@ def __init__(self, *, timeout: Union[int, str, None] = None):
4443
self.timeout = timeout
4544

4645
def _to_rest_object(self) -> RestCommandJobLimits:
47-
if isinstance(self.timeout, PipelineInput):
46+
if is_data_binding_expression(self.timeout):
4847
return RestCommandJobLimits(timeout=self.timeout)
4948
return RestCommandJobLimits(timeout=to_iso_duration_format(self.timeout))
5049

0 commit comments

Comments
 (0)