Skip to content

Commit

Permalink
Merge branch 'main' into AIP72-connections-extra-dejson
Browse files Browse the repository at this point in the history
  • Loading branch information
amoghrajesh committed Jan 7, 2025
2 parents 5283236 + a6da8df commit 78d08ee
Show file tree
Hide file tree
Showing 10 changed files with 285 additions and 25 deletions.
1 change: 1 addition & 0 deletions task_sdk/src/airflow/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"EdgeModifier": ".definitions.edges",
"Label": ".definitions.edges",
"Connection": ".definitions.connection",
"Variable": ".definitions.variable",
}


Expand Down
15 changes: 13 additions & 2 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,20 @@ class VariableOperations:
def __init__(self, client: Client):
self.client = client

def get(self, key: str) -> VariableResponse:
def get(self, key: str) -> VariableResponse | ErrorResponse:
"""Get a variable from the API server."""
resp = self.client.get(f"variables/{key}")
try:
resp = self.client.get(f"variables/{key}")
except ServerResponseError as e:
if e.response.status_code == HTTPStatus.NOT_FOUND:
log.error(
"Variable not found",
key=key,
detail=e.detail,
status_code=e.response.status_code,
)
return ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"key": key})
raise
return VariableResponse.model_validate_json(resp.read())

def set(self, key: str, value: str | None, description: str | None = None):
Expand Down
41 changes: 41 additions & 0 deletions task_sdk/src/airflow/sdk/definitions/variable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import Any

import attrs


@attrs.define
class Variable:
"""
A generic way to store and retrieve arbitrary content or settings as a simple key/value store.
:param key: The variable key.
:param value: The variable value.
:param description: The variable description.
"""

key: str
# keeping as any for supporting deserialize_json
value: Any | None = None
description: str | None = None

# TODO: Extend this definition for reading/writing variables without context
2 changes: 1 addition & 1 deletion task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def from_variable_response(cls, variable_response: VariableResponse) -> Variable
VariableResponse is autogenerated from the API schema, so we need to convert it to VariableResult
for communication between the Supervisor and the task process.
"""
return cls(**variable_response.model_dump())
return cls(**variable_response.model_dump(exclude_defaults=True), type="VariableResult")


class ErrorResponse(BaseModel):
Expand Down
63 changes: 61 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,31 @@
import structlog

from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
from airflow.sdk.types import NOTSET

if TYPE_CHECKING:
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.execution_time.comms import ConnectionResult
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.execution_time.comms import ConnectionResult, VariableResult


def _convert_connection_result_conn(conn_result: ConnectionResult):
def _convert_connection_result_conn(conn_result: ConnectionResult) -> Connection:
from airflow.sdk.definitions.connection import Connection

# `by_alias=True` is used to convert the `schema` field to `schema_` in the Connection model
return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))


def _convert_variable_result_to_variable(var_result: VariableResult, deserialize_json: bool) -> Variable:
from airflow.sdk.definitions.variable import Variable

if deserialize_json:
import json

var_result.value = json.loads(var_result.value) # type: ignore
return Variable(**var_result.model_dump(exclude={"type"}))


def _get_connection(conn_id: str) -> Connection:
# TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
# or `airflow.sdk.execution_time.connection`
Expand All @@ -54,6 +66,26 @@ def _get_connection(conn_id: str) -> Connection:
return _convert_connection_result_conn(msg)


def _get_variable(key: str, deserialize_json: bool) -> Variable:
# TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
# or `airflow.sdk.execution_time.variable`
# A reason to not move it to `airflow.sdk.execution_time.comms` is that it
# will make that module depend on Task SDK, which is not ideal because we intend to
# keep Task SDK as a separate package than execution time mods.
from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key=key))
msg = SUPERVISOR_COMMS.get_message()
if isinstance(msg, ErrorResponse):
raise AirflowRuntimeError(msg)

if TYPE_CHECKING:
assert isinstance(msg, VariableResult)
return _convert_variable_result_to_variable(msg, deserialize_json)


class ConnectionAccessor:
"""Wrapper to access Connection entries in template."""

Expand All @@ -76,3 +108,30 @@ def get(self, conn_id: str, default_conn: Any = None) -> Any:
if e.error.error == ErrorType.CONNECTION_NOT_FOUND:
return default_conn
raise


class VariableAccessor:
"""Wrapper to access Variable values in template."""

def __init__(self, deserialize_json: bool) -> None:
self._deserialize_json = deserialize_json

def __eq__(self, other):
if not isinstance(other, VariableAccessor):
return False
# All instances of VariableAccessor are equal since it is a stateless dynamic accessor
return True

def __repr__(self) -> str:
return "<VariableAccessor (dynamic access)>"

def __getattr__(self, key: str) -> Any:
return _get_variable(key, self._deserialize_json)

def get(self, key, default_var: Any = NOTSET) -> Any:
try:
return _get_variable(key, self._deserialize_json)
except AirflowRuntimeError as e:
if e.error.error == ErrorType.VARIABLE_NOT_FOUND:
return default_var
raise
8 changes: 6 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
IntermediateTIState,
TaskInstance,
TerminalTIState,
VariableResponse,
)
from airflow.sdk.execution_time.comms import (
ConnectionResult,
Expand Down Expand Up @@ -722,8 +723,11 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
resp = conn.model_dump_json().encode()
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
var_result = VariableResult.from_variable_response(var)
resp = var_result.model_dump_json().encode()
if isinstance(var, VariableResponse):
var_result = VariableResult.from_variable_response(var)
resp = var_result.model_dump_json(exclude_unset=True).encode()
else:
resp = var.model_dump_json().encode()
elif isinstance(msg, GetXCom):
xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
xcom_result = XComResult.from_xcom_response(xcom)
Expand Down
10 changes: 5 additions & 5 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
ToTask,
XComResult,
)
from airflow.sdk.execution_time.context import ConnectionAccessor
from airflow.sdk.execution_time.context import ConnectionAccessor, VariableAccessor

if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger as Logger
Expand Down Expand Up @@ -85,10 +85,10 @@ def get_template_context(self):
# "prev_end_date_success": get_prev_end_date_success(),
# "test_mode": task_instance.test_mode,
# "triggering_asset_events": lazy_object_proxy.Proxy(get_triggering_events),
# "var": {
# "json": VariableAccessor(deserialize_json=True),
# "value": VariableAccessor(deserialize_json=False),
# },
"var": {
"json": VariableAccessor(deserialize_json=True),
"value": VariableAccessor(deserialize_json=False),
},
"conn": ConnectionAccessor(),
}
if self._ti_context_from_server:
Expand Down
36 changes: 26 additions & 10 deletions task_sdk/tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,16 +362,32 @@ def handle_request(request: httpx.Request) -> httpx.Response:

client = make_client(transport=httpx.MockTransport(handle_request))

with pytest.raises(ServerResponseError) as err:
client.variables.get(key="non_existent_var")

assert err.value.response.status_code == 404
assert err.value.detail == {
"detail": {
"message": "Variable with key 'non_existent_var' not found",
"reason": "not_found",
}
}
resp = client.variables.get(key="non_existent_var")

assert isinstance(resp, ErrorResponse)
assert resp.error == ErrorType.VARIABLE_NOT_FOUND
assert resp.detail == {"key": "non_existent_var"}

@mock.patch("time.sleep", return_value=None)
def test_variable_get_500_error(self, mock_sleep):
# Simulate a response from the server returning a 500 error
def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == "/variables/test_key":
return httpx.Response(
status_code=500,
headers=[("content-Type", "application/json")],
json={
"reason": "internal_server_error",
"message": "Internal Server Error",
},
)
return httpx.Response(status_code=400, json={"detail": "Bad Request"})

client = make_client(transport=httpx.MockTransport(handle_request))
with pytest.raises(ServerResponseError):
client.variables.get(
key="test_key",
)

def test_variable_set_success(self):
# Simulate a successful response from the server when putting a variable
Expand Down
76 changes: 74 additions & 2 deletions task_sdk/tests/execution_time/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,15 @@
from unittest.mock import MagicMock, patch

from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse
from airflow.sdk.execution_time.context import ConnectionAccessor, _convert_connection_result_conn
from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse, VariableResult
from airflow.sdk.execution_time.context import (
ConnectionAccessor,
VariableAccessor,
_convert_connection_result_conn,
_convert_variable_result_to_variable,
)


def test_convert_connection_result_conn():
Expand Down Expand Up @@ -50,6 +56,31 @@ def test_convert_connection_result_conn():
)


def test_convert_variable_result_to_variable():
"""Test that the VariableResult is converted to a Variable object."""
var = VariableResult(
key="test_key",
value="test_value",
)
var = _convert_variable_result_to_variable(var, deserialize_json=False)
assert var == Variable(
key="test_key",
value="test_value",
)


def test_convert_variable_result_to_variable_with_deserialize_json():
"""Test that the VariableResult is converted to a Variable object with deserialize_json set to True."""
var = VariableResult(
key="test_key",
value='{\r\n "key1": "value1",\r\n "key2": "value2",\r\n "enabled": true,\r\n "threshold": 42\r\n}',
)
var = _convert_variable_result_to_variable(var, deserialize_json=True)
assert var == Variable(
key="test_key", value={"key1": "value1", "key2": "value2", "enabled": True, "threshold": 42}
)


class TestConnectionAccessor:
def test_getattr_connection(self, mock_supervisor_comms):
"""
Expand Down Expand Up @@ -135,3 +166,44 @@ def test_getattr_connection_for_extra_dejson_decode_error(self, mock_get_logger,
mock_logger.error.assert_called_once_with(
"Failed to deserialize extra property `extra`, returning empty dictionary"
)


class TestVariableAccessor:
def test_getattr_variable(self, mock_supervisor_comms):
"""
Test that the variable is fetched when accessed via __getattr__.
"""
accessor = VariableAccessor(deserialize_json=False)

# Variable from the supervisor / API Server
var_result = VariableResult(key="test_key", value="test_value")

mock_supervisor_comms.get_message.return_value = var_result

# Fetch the variable; triggers __getattr__
var = accessor.test_key

expected_var = Variable(key="test_key", value="test_value")
assert var == expected_var

def test_get_method_valid_variable(self, mock_supervisor_comms):
"""Test that the get method returns the requested variable using `var.get`."""
accessor = VariableAccessor(deserialize_json=False)
var_result = VariableResult(key="test_key", value="test_value")

mock_supervisor_comms.get_message.return_value = var_result

var = accessor.get("test_key")
assert var == Variable(key="test_key", value="test_value")

def test_get_method_with_default(self, mock_supervisor_comms):
"""Test that the get method returns the default variable when the requested variable is not found."""

accessor = VariableAccessor(deserialize_json=False)
default_var = {"default_key": "default_value"}
error_response = ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"test_key": "test_value"})

mock_supervisor_comms.get_message.return_value = error_response

var = accessor.get("nonexistent_var_key", default_var=default_var)
assert var == default_var
Loading

0 comments on commit 78d08ee

Please sign in to comment.