Skip to content

Commit 0db5838

Browse files
authored
fix: HITL params not validating (#57547)
1 parent e4a60ca commit 0db5838

File tree

13 files changed

+343
-156
lines changed

13 files changed

+343
-156
lines changed

airflow-core/src/airflow/api_fastapi/core_api/datamodels/hitl.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class BaseHITLDetail(BaseModel):
5858
body: str | None = None
5959
defaults: list[str] | None = None
6060
multiple: bool = False
61-
params: dict[str, Any] = Field(default_factory=dict)
61+
params: Mapping = Field(default_factory=dict)
6262
assigned_users: list[HITLUser] = Field(default_factory=list)
6363
created_at: datetime
6464

@@ -74,7 +74,20 @@ class BaseHITLDetail(BaseModel):
7474
@classmethod
7575
def get_params(cls, params: dict[str, Any]) -> dict[str, Any]:
7676
"""Convert params attribute to dict representation."""
77-
return {k: v.dump() if getattr(v, "dump", None) else v for k, v in params.items()}
77+
return {
78+
key: value
79+
if BaseHITLDetail._is_param(value)
80+
else {
81+
"value": value,
82+
"description": None,
83+
"schema": {},
84+
}
85+
for key, value in params.items()
86+
}
87+
88+
@staticmethod
89+
def _is_param(value: Any) -> bool:
90+
return isinstance(value, dict) and all(key in value for key in ("description", "schema", "value"))
7891

7992

8093
class HITLDetail(BaseHITLDetail):

airflow-core/src/airflow/ui/src/utils/hitl.ts

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import type { TFunction } from "i18next";
2020

2121
import type { HITLDetail } from "openapi/requests/types.gen";
22-
import type { ParamsSpec } from "src/queries/useDagParams";
22+
import type { ParamSchema, ParamsSpec } from "src/queries/useDagParams";
2323

2424
export type HITLResponseParams = {
2525
chosen_options?: Array<string>;
@@ -70,7 +70,7 @@ export const getHITLParamsDict = (
7070
searchParams: URLSearchParams,
7171
): ParamsSpec => {
7272
const paramsDict: ParamsSpec = {};
73-
const { preloadedHITLOptions, preloadedHITLParams } = getPreloadHITLFormData(searchParams, hitlDetail);
73+
const { preloadedHITLOptions } = getPreloadHITLFormData(searchParams, hitlDetail);
7474
const isApprovalTask =
7575
hitlDetail.options.includes("Approve") &&
7676
hitlDetail.options.includes("Reject") &&
@@ -108,27 +108,36 @@ export const getHITLParamsDict = (
108108
const sourceParams = hitlDetail.response_received ? hitlDetail.params_input : hitlDetail.params;
109109

110110
Object.entries(sourceParams ?? {}).forEach(([key, value]) => {
111-
const valueType = typeof value === "number" ? "number" : "string";
111+
if (!hitlDetail.params) {
112+
return;
113+
}
114+
const paramData = hitlDetail.params[key] as ParamsSpec | undefined;
115+
116+
const description: string =
117+
paramData && typeof paramData.description === "string" ? paramData.description : "";
118+
119+
const schema: ParamSchema = {
120+
const: undefined,
121+
description_md: "",
122+
enum: undefined,
123+
examples: undefined,
124+
format: undefined,
125+
items: undefined,
126+
maximum: undefined,
127+
maxLength: undefined,
128+
minimum: undefined,
129+
minLength: undefined,
130+
section: undefined,
131+
title: key,
132+
type: typeof value === "number" ? "number" : "string",
133+
values_display: undefined,
134+
...(paramData?.schema && typeof paramData.schema === "object" ? paramData.schema : {}),
135+
};
112136

113137
paramsDict[key] = {
114-
description: "",
115-
schema: {
116-
const: undefined,
117-
description_md: "",
118-
enum: undefined,
119-
examples: undefined,
120-
format: undefined,
121-
items: undefined,
122-
maximum: undefined,
123-
maxLength: undefined,
124-
minimum: undefined,
125-
minLength: undefined,
126-
section: undefined,
127-
title: key,
128-
type: valueType,
129-
values_display: undefined,
130-
},
131-
value: preloadedHITLParams[key] ?? value,
138+
description,
139+
schema,
140+
value: paramData?.value ?? value,
132141
};
133142
});
134143
}

airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_hitl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def expected_sample_hitl_detail_dict(sample_ti: TaskInstance) -> dict[str, Any]:
216216
"defaults": ["Approve"],
217217
"multiple": False,
218218
"options": ["Approve", "Reject"],
219-
"params": {"input_1": 1},
219+
"params": {"input_1": {"value": 1, "schema": {}, "description": None}},
220220
"assigned_users": [],
221221
"created_at": mock.ANY,
222222
"params_input": {},
@@ -621,7 +621,7 @@ def test_should_respond_200_with_existing_response_and_concrete_query(
621621
"body": "this is body 0",
622622
"defaults": ["Approve"],
623623
"multiple": False,
624-
"params": {"input_1": 1},
624+
"params": {"input_1": {"value": 1, "schema": {}, "description": None}},
625625
"assigned_users": [],
626626
"created_at": DEFAULT_CREATED_AT.isoformat().replace("+00:00", "Z"),
627627
"responded_by_user": None,

airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2275,7 +2275,7 @@ def test_should_respond_200_with_hitl(
22752275
"defaults": ["Approve"],
22762276
"multiple": False,
22772277
"options": ["Approve", "Reject"],
2278-
"params": {"input_1": 1},
2278+
"params": {"input_1": {"value": 1, "description": None, "schema": {}}},
22792279
"params_input": {},
22802280
"responded_at": None,
22812281
"responded_by_user": None,
@@ -3554,7 +3554,7 @@ def test_should_respond_200_with_hitl(
35543554
"defaults": ["Approve"],
35553555
"multiple": False,
35563556
"options": ["Approve", "Reject"],
3557-
"params": {"input_1": 1},
3557+
"params": {"input_1": {"value": 1, "description": None, "schema": {}}},
35583558
"params_input": {},
35593559
"responded_at": None,
35603560
"responded_by_user": None,

devel-common/src/tests_common/test_utils/version_compat.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
3636
AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
3737
AIRFLOW_V_3_0_3_PLUS = get_base_airflow_version_tuple() >= (3, 0, 3)
3838
AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)
39+
AIRFLOW_V_3_1_3_PLUS = get_base_airflow_version_tuple() >= (3, 1, 3)
3940
AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0)
4041

4142

providers/standard/src/airflow/providers/standard/operators/hitl.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import logging
2020

2121
from airflow.exceptions import AirflowOptionalProviderFeatureException
22-
from airflow.providers.standard.version_compat import AIRFLOW_V_3_1_PLUS
22+
from airflow.providers.standard.version_compat import AIRFLOW_V_3_1_3_PLUS, AIRFLOW_V_3_1_PLUS
2323

2424
if not AIRFLOW_V_3_1_PLUS:
2525
raise AirflowOptionalProviderFeatureException("Human in the loop functionality needs Airflow 3.1+.")
@@ -84,6 +84,7 @@ def __init__(
8484
self.multiple = multiple
8585

8686
self.params: ParamsDict = params if isinstance(params, ParamsDict) else ParamsDict(params or {})
87+
8788
self.notifiers: Sequence[BaseNotifier] = (
8889
[notifiers] if isinstance(notifiers, BaseNotifier) else notifiers or []
8990
)
@@ -110,6 +111,7 @@ def validate_params(self) -> None:
110111
Raises:
111112
ValueError: If `"_options"` key is present in `params`, which is not allowed.
112113
"""
114+
self.params.validate()
113115
if "_options" in self.params:
114116
raise ValueError('"_options" is not allowed in params')
115117

@@ -165,8 +167,10 @@ def execute(self, context: Context):
165167
)
166168

167169
@property
168-
def serialized_params(self) -> dict[str, Any]:
169-
return self.params.dump() if isinstance(self.params, ParamsDict) else self.params
170+
def serialized_params(self) -> dict[str, dict[str, Any]]:
171+
if not AIRFLOW_V_3_1_3_PLUS:
172+
return self.params.dump() if isinstance(self.params, ParamsDict) else self.params
173+
return {k: self.params.get_param(k).serialize() for k in self.params}
170174

171175
def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
172176
if "error" in event:
@@ -196,13 +200,12 @@ def validate_chosen_options(self, chosen_options: list[str]) -> None:
196200

197201
def validate_params_input(self, params_input: Mapping) -> None:
198202
"""Check whether user provide valid params input."""
199-
if (
200-
self.serialized_params is not None
201-
and params_input is not None
202-
and set(self.serialized_params.keys()) ^ set(params_input)
203-
):
203+
if self.params and params_input and set(self.serialized_params.keys()) ^ set(params_input):
204204
raise ValueError(f"params_input {params_input} does not match params {self.params}")
205205

206+
for key, value in params_input.items():
207+
self.params[key] = value
208+
206209
def generate_link_to_ui(
207210
self,
208211
*,

0 commit comments

Comments
 (0)