Skip to content

Commit 8c8343c

Browse files
authored
[ML][Pipelines] feat: handle corner serialization issues for internal components (Azure#28390)
* feat: handle corner case serialization for internal components
* feat: support attr datatransfer in internal components
* fix: bandit
* refactor: extract safe load with base resolver to a util
* ci: fix ci
1 parent 8354b44 commit 8c8343c

File tree

17 files changed

+1443
-333
lines changed

17 files changed

+1443
-333
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# ---------------------------------------------------------
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# ---------------------------------------------------------
4-
# pylint: disable=unused-import
5-
from ._util import enable_internal_components_in_pipeline
4+
5+
from ._setup import enable_internal_components_in_pipeline
66
from .entities import (
77
Ae365exepool,
88
AISuperComputerConfiguration,

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/component.py

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
# ---------------------------------------------------------
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# ---------------------------------------------------------
4+
import os.path
45

6+
import pydash
57
from marshmallow import EXCLUDE, INCLUDE, fields, post_dump, pre_load
68

79
from azure.ai.ml._schema import NestedField, StringTransformedEnum, UnionField
810
from azure.ai.ml._schema.component.component import ComponentSchema
911
from azure.ai.ml._schema.core.fields import ArmVersionedStr, CodeField
10-
from azure.ai.ml.constants._common import LABELLED_RESOURCE_NAME, AzureMLResourceType
12+
from azure.ai.ml.constants._common import LABELLED_RESOURCE_NAME, AzureMLResourceType, SOURCE_PATH_CONTEXT_KEY
1113

14+
from .._utils import yaml_safe_load_with_base_resolver
1215
from ..._utils._arm_id_utils import parse_name_label
1316
from .environment import InternalEnvironmentSchema
1417
from .input_output import (
@@ -110,29 +113,24 @@ def _serialize(self, obj, *, many: bool = False):
110113
ret[attr_name] = self.get_attribute(obj, attr_name, None)
111114
return ret
112115

113-
@pre_load()
114-
def convert_input_value_to_str(self, data: dict, **kwargs) -> dict:
115-
"""
116-
Convert the non-str value in input to str.
117-
118-
When load the v1.5 component yaml, true/false will be converted to bool type and yyyy-mm-dd will be
119-
converted to date type. In order to be consistent with before, it needs to be converted to str type.
120-
"""
121-
def convert_to_str(value):
122-
if isinstance(value, bool):
123-
return str(value).lower()
124-
return str(value)
125-
126-
if "inputs" in data and isinstance(data["inputs"], dict):
127-
for input_port in data["inputs"].values():
128-
input_type = input_port["type"]
129-
# input type can be a list for internal component
130-
if isinstance(input_type, str) and input_type.lower() in ["string", "enum"]:
131-
if not isinstance(input_port.get("default", ""), str):
132-
input_port["default"] = convert_to_str(input_port["default"])
133-
if "enum" in input_port and any([not isinstance(item, str) for item in input_port["enum"]]):
134-
input_port["enum"] = [convert_to_str(item) for item in input_port["enum"]]
135-
return data
116+
# override param_override to ensure that param override happens after reloading the yaml
117+
@pre_load
def add_param_overrides(self, data, **kwargs):
    """Restore raw YAML scalars from the source file, then apply param overrides.

    When the schema context carries ``SOURCE_PATH_CONTEXT_KEY`` and it points
    to an existing file, that file is re-read with
    ``yaml_safe_load_with_base_resolver`` and the values of ``version`` and of
    every input's ``default`` / ``enum`` / ``min`` / ``max`` are copied from
    the re-read document into ``data``. A key is only overwritten when it is
    present in both ``data`` and the re-read document.

    The source path is popped from the context, so it is consumed by this
    call. The parent ``add_param_overrides`` hook is invoked afterwards so
    that param overrides are applied on top of the restored values.

    :param data: The loaded YAML content about to be deserialized.
    :param kwargs: Extra arguments forwarded to the parent hook.
    :return: Whatever the parent ``add_param_overrides`` hook returns.
    """
    source_path = self.context.pop(SOURCE_PATH_CONTEXT_KEY, None)
    if isinstance(data, dict) and source_path and os.path.isfile(source_path):
        # Read with an explicit encoding: relying on the platform default
        # would decode the YAML file with the locale code page on some systems.
        with open(source_path, "r", encoding="utf-8") as f:
            origin_data = yaml_safe_load_with_base_resolver(f)

        dot_keys = ["version"]
        # ``inputs`` may be missing, null or malformed in the YAML; only walk
        # it when it is a mapping and leave error reporting to schema validation.
        inputs = data.get("inputs")
        if isinstance(inputs, dict):
            for input_key in inputs:
                # Keep value in float input as string to avoid precision issue.
                for attr_name in ("default", "enum", "min", "max"):
                    dot_keys.append(f"inputs.{input_key}.{attr_name}")

        for dot_key in dot_keys:
            if pydash.has(data, dot_key) and pydash.has(origin_data, dot_key):
                pydash.set_(data, dot_key, pydash.get(origin_data, dot_key))
    return super().add_param_overrides(data, **kwargs)
136134

137135
@post_dump(pass_original=True)
138136
def simplify_input_output_port(self, data, original, **kwargs): # pylint:disable=unused-argument, no-self-use
@@ -143,13 +141,6 @@ def simplify_input_output_port(self, data, original, **kwargs): # pylint:disabl
143141

144142
# hack, to match current serialization match expectation
145143
for port_name, port_definition in data["inputs"].items():
146-
input_type = port_definition.get("type", None)
147-
is_float_type = isinstance(input_type, str) and input_type.lower() == "float"
148-
for key in ["default", "min", "max"]:
149-
if key in port_definition:
150-
value = getattr(original.inputs[port_name], key)
151-
# Keep value in float input as string to avoid precision issue.
152-
data["inputs"][port_name][key] = str(value) if is_float_type else value
153144
if "mode" in port_definition:
154145
del port_definition["mode"]
155146

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
5+
from ._yaml_utils import yaml_safe_load_with_base_resolver
6+
7+
__all__ = [
8+
"yaml_safe_load_with_base_resolver"
9+
]
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
5+
6+
import strictyaml
7+
8+
9+
class _SafeLoaderWithBaseLoader(strictyaml.ruamel.SafeLoader):
    """SafeLoader that behaves as if it used a base resolver.

    How BaseResolver and VersionedResolver differ:
    1) BaseResolver does not resolve node values. VersionedResolver turns
       "yes" / "no" into true / false (bool); BaseResolver leaves them alone.
    2) VersionedResolver loads its pattern matching rules lazily so that the
       yaml version can be taken into account on loading.

    SafeLoader inherits from VersionedResolver, so VersionedResolver cannot
    simply be dropped from the inheritance list. Instead,
    ``add_version_implicit_resolver`` is overridden so that no matching rule
    is ever registered in ``_version_implicit_resolver``; with no rules, the
    resolver acts like a BaseResolver.
    """

    def fetch_comment(self, comment):
        """Accept a comment and do nothing with it."""
        return None

    def add_version_implicit_resolver(self, version, tag, regexp, first):
        """Register an empty rule table for ``version`` and ignore the rule itself.

        :param version: version of yaml, like (1, 1)(yaml 1.1) and (1, 2)(yaml 1.2)
        :type version: VersionType
        :param tag: a tag indicating the type of the resolved node, e.g., tag:yaml.org,2002:bool.
        :param regexp: the regular expression to match the node to be resolved
        :param first: a list of first characters to match
        """
        # Only make sure an entry exists for this version; tag, regexp and
        # first are intentionally never stored.
        registered = self._version_implicit_resolver
        if version not in registered:
            registered[version] = {}
34+
35+
36+
def yaml_safe_load_with_base_resolver(stream):
    """Load a yaml stream without implicit type resolution of scalars.

    The stream is parsed with ``_SafeLoaderWithBaseLoader``, so scalars stay
    as they are written instead of being converted by the version default
    resolver. For example:
    1) "yes" and "no" are loaded as "yes"(string) and "no"(string) instead of "true"(bool) and "false"(bool);
    2) "0.10" is loaded as "0.10"(string) instead of "0.1"(float);
    3) "2019-01-01" is loaded as "2019-01-01"(string) instead of "2019-01-01T00:00:00Z"(datetime);
    4) "1" is loaded as "1"(string) instead of "1"(int);
    5) "1.0" is loaded as "1.0"(string) instead of "1.0"(float);
    6) "~" is loaded as "~"(string) instead of "None"(NoneType).
    Please refer to strictyaml.ruamel.resolver.implicit_resolvers for more details.

    :param stream: The yaml content to load.
    :return: The loaded document.
    """
    loader_cls = _SafeLoaderWithBaseLoader
    return strictyaml.ruamel.load(stream, Loader=loader_cls)

sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/component.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def __init__(
9696
starlite: Optional[Dict] = None,
9797
ae365exepool: Optional[Dict] = None,
9898
launcher: Optional[Dict] = None,
99+
datatransfer: Optional[Dict] = None,
99100
**kwargs,
100101
):
101102
type, self._type_label = parse_name_label(type)
@@ -135,6 +136,7 @@ def __init__(
135136
self.starlite = starlite
136137
self.ae365exepool = ae365exepool
137138
self.launcher = launcher
139+
self.datatransfer = datatransfer
138140

139141
@classmethod
140142
def _build_io(cls, io_dict: Union[Dict, Input, Output], is_input: bool):

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_component/component.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
REGISTRY_URI_FORMAT,
2727
CommonYamlFields,
2828
AzureMLResourceType,
29+
SOURCE_PATH_CONTEXT_KEY,
2930
)
3031
from azure.ai.ml.constants._component import ComponentSource, NodeType, IOConstants
3132
from azure.ai.ml.entities._assets import Code
@@ -328,6 +329,7 @@ def _load(
328329
create_schema_func(
329330
{
330331
BASE_PATH_CONTEXT_KEY: base_path,
332+
SOURCE_PATH_CONTEXT_KEY: yaml_path,
331333
PARAMS_OVERRIDE_KEY: params_override,
332334
}
333335
).load(data, unknown=INCLUDE, **kwargs)

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/input.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,4 +486,4 @@ def _map_from_rest_type(cls, _type):
486486
def _from_rest_object(cls, obj: Dict) -> "Input":
487487
obj["type"] = cls._map_from_rest_type(obj["type"])
488488

489-
return Input(**obj)
489+
return cls(**obj)

sdk/ml/azure-ai-ml/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,7 @@ def disable_internal_components():
854854
and enable_private_preview_features, as the execution order of fixtures is not guaranteed.
855855
"""
856856
from azure.ai.ml._internal._schema.component import NodeType
857-
from azure.ai.ml._internal._util import _set_registered
857+
from azure.ai.ml._internal._setup import _set_registered
858858
from azure.ai.ml.entities._component.component_factory import component_factory
859859
from azure.ai.ml.entities._job.pipeline._load_component import pipeline_node_factory
860860

sdk/ml/azure-ai-ml/tests/internal/e2etests/test_component.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def test_component_load(
8484
randstr: Callable[[str], str],
8585
yaml_path: str,
8686
) -> None:
87+
if "ae365" not in yaml_path:
88+
return
8789
omit_fields = ["id", "creation_context", "code", "name"]
8890
component_name = randstr("component_name")
8991

@@ -101,6 +103,12 @@ def test_component_load(
101103
expected_dict = json.load(f)
102104
expected_dict["_source"] = "REMOTE.WORKSPACE.COMPONENT"
103105

106+
# default value for datatransfer
107+
if expected_dict["type"] == "DataTransferComponent" and "datatransfer" not in expected_dict:
108+
expected_dict["datatransfer"] = {
109+
'allow_overwrite': 'True'
110+
}
111+
104112
# TODO: check if loaded environment is expected to be an ordered dict
105113
assert pydash.omit(loaded_dict, *omit_fields) == pydash.omit(expected_dict, *omit_fields)
106114

0 commit comments

Comments
 (0)