Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Exposing 'extra_dejson' on Connection definition #45448

Merged
merged 7 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 7 additions & 23 deletions task_sdk/src/airflow/sdk/definitions/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import json
import logging
from contextlib import suppress
from json import JSONDecodeError

import attrs
Expand Down Expand Up @@ -59,28 +58,13 @@ def get_uri(self): ...
def get_hook(self): ...

@property
def extra_dejson(self, nested: bool = False) -> dict:
"""
Deserialize extra property to JSON.

:param nested: Determines whether nested structures are also deserialized into JSON (default False).
"""
extra_json = {}

def extra_dejson(self) -> dict:
"""Deserialize `extra` property to JSON."""
extra = {}
if self.extra:
try:
if nested:
for key, value in json.loads(self.extra).items():
extra_json[key] = value
if isinstance(value, str):
with suppress(JSONDecodeError):
extra_json[key] = json.loads(value)
else:
extra_json = json.loads(self.extra)
extra = json.loads(self.extra)
except JSONDecodeError:
log.exception("Failed parsing the json for conn_id %s", self.conn_id)

# TODO: Mask sensitive keys from this list
# mask_secret(extra)

return extra_json
log.exception("Failed to deserialize extra property `extra`, returning empty dictionary")
# TODO: Mask sensitive keys from this list or revisit if it will be done in server
return extra
44 changes: 44 additions & 0 deletions task_sdk/tests/execution_time/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from __future__ import annotations

from unittest.mock import MagicMock, patch

from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.exceptions import ErrorType
Expand Down Expand Up @@ -122,6 +124,48 @@ def test_get_method_with_default(self, mock_supervisor_comms):
conn = accessor.get("nonexistent_conn", default_conn=default_conn)
assert conn == default_conn

def test_getattr_connection_for_extra_dejson(self, mock_supervisor_comms):
accessor = ConnectionAccessor()

# Conn from the supervisor / API Server
conn_result = ConnectionResult(
conn_id="mysql_conn",
conn_type="mysql",
host="mysql",
port=3306,
extra='{"extra_key": "extra_value"}',
)

mock_supervisor_comms.get_message.return_value = conn_result

# Fetch the connection's dejson; triggers __getattr__
dejson = accessor.mysql_conn.extra_dejson

assert dejson == {"extra_key": "extra_value"}

@patch("airflow.sdk.definitions.connection.log", create=True)
def test_getattr_connection_for_extra_dejson_decode_error(self, mock_log, mock_supervisor_comms):
mock_log.return_value = MagicMock()

accessor = ConnectionAccessor()

# Conn from the supervisor / API Server
conn_result = ConnectionResult(
conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306, extra="This is not JSON!"
)

mock_supervisor_comms.get_message.return_value = conn_result

# Fetch the connection's dejson; triggers __getattr__
dejson = accessor.mysql_conn.extra_dejson

# empty in case of failed deserialising
assert dejson == {}

mock_log.exception.assert_called_once_with(
"Failed to deserialize extra property `extra`, returning empty dictionary"
)


class TestVariableAccessor:
def test_getattr_variable(self, mock_supervisor_comms):
Expand Down
3 changes: 3 additions & 0 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,9 @@ def test_get_connection_from_context(self, mocked_parse, make_ti_context, mock_s
extra='{"extra_key": "extra_value"}',
)

dejson_from_conn = conn_from_context.extra_dejson
assert dejson_from_conn == {"extra_key": "extra_value"}

def test_template_render(self, mocked_parse, make_ti_context):
task = BaseOperator(task_id="test_template_render_task")

Expand Down