From 370f443ddd1cabfebbbb048d86b3a5de48f6da64 Mon Sep 17 00:00:00 2001
From: Pankaj Koti
Date: Mon, 22 Jan 2024 14:35:06 +0530
Subject: [PATCH] Fix mypy failures

- example_aws_nuke: pass an empty mapping instead of None to execute(),
  matching the Context argument the operator expects.
- example_sagemaker: add scoped "type: ignore[index]" comments where the
  set-up task's return value is indexed at DAG parse time.
- core/sensors/external_task: adopt "from __future__ import annotations"
  with dict[...]/list[...]/"| None" syntax and annotate execution_date
  explicitly.
- core/sensors/filesystem: annotate __init__ arguments and return type.
---
 .../aws/example_dags/example_aws_nuke.py      |  2 +-
 .../aws/example_dags/example_sagemaker.py     | 24 +++++++++----------
 .../providers/core/sensors/external_task.py   | 12 ++++++----
 .../providers/core/sensors/filesystem.py      |  2 +-
 4 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/astronomer/providers/amazon/aws/example_dags/example_aws_nuke.py b/astronomer/providers/amazon/aws/example_dags/example_aws_nuke.py
index eb6950f54..0bb3601ea 100644
--- a/astronomer/providers/amazon/aws/example_dags/example_aws_nuke.py
+++ b/astronomer/providers/amazon/aws/example_dags/example_aws_nuke.py
@@ -79,7 +79,7 @@ def generate_task_report(**context: Any) -> None:
             message=report,
             channel=SLACK_CHANNEL,
             username=SLACK_USERNAME,
-        ).execute(context=None)
+        ).execute(context={})
     except Exception as exception:
         logging.exception("Error occur while sending slack alert.")
         raise exception
diff --git a/astronomer/providers/amazon/aws/example_dags/example_sagemaker.py b/astronomer/providers/amazon/aws/example_dags/example_sagemaker.py
index 4b74e0b67..4ab553ae2 100644
--- a/astronomer/providers/amazon/aws/example_dags/example_sagemaker.py
+++ b/astronomer/providers/amazon/aws/example_dags/example_sagemaker.py
@@ -338,14 +338,14 @@ def check_dag_status(**kwargs: Any) -> None:
     create_bucket = S3CreateBucketOperator(
         task_id="create_bucket",
         aws_conn_id=SAGEMAKER_CONN_ID,
-        bucket_name=test_setup["bucket_name"],
+        bucket_name=test_setup["bucket_name"],  # type: ignore[index]
     )
 
     upload_dataset = S3CreateObjectOperator(
         task_id="upload_dataset",
         aws_conn_id=SAGEMAKER_CONN_ID,
-        s3_bucket=test_setup["bucket_name"],
-        s3_key=test_setup["raw_data_s3_key_input"],
+        s3_bucket=test_setup["bucket_name"],  # type: ignore[index]
+        s3_key=test_setup["raw_data_s3_key_input"],  # type: ignore[index]
         data=DATASET,
         replace=True,
     )
@@ -353,8 +353,8 @@ def check_dag_status(**kwargs: Any) -> None:
     upload_training_dataset = S3CreateObjectOperator(
         task_id="upload_training_dataset",
         aws_conn_id=SAGEMAKER_CONN_ID,
-        s3_bucket=test_setup["bucket_name"],
-        s3_key=test_setup["train_data_csv"],
+        s3_bucket=test_setup["bucket_name"],  # type: ignore[index]
+        s3_key=test_setup["train_data_csv"],  # type: ignore[index]
         data=TRAIN_DATASET,
         replace=True,
     )
@@ -362,8 +362,8 @@ def check_dag_status(**kwargs: Any) -> None:
     upload_transform_dataset = S3CreateObjectOperator(
         task_id="upload_transform_dataset",
         aws_conn_id=SAGEMAKER_CONN_ID,
-        s3_bucket=test_setup["bucket_name"],
-        s3_key=test_setup["transform_data_csv"],
+        s3_bucket=test_setup["bucket_name"],  # type: ignore[index]
+        s3_key=test_setup["transform_data_csv"],  # type: ignore[index]
         data=TRANSFORM_DATASET,
         replace=True,
     )
@@ -371,7 +371,7 @@ def check_dag_status(**kwargs: Any) -> None:
     preprocess_raw_data = SageMakerProcessingOperatorAsync(
         task_id="preprocess_raw_data",
         aws_conn_id=SAGEMAKER_CONN_ID,
-        config=test_setup["processing_config"],
+        config=test_setup["processing_config"],  # type: ignore[index]
     )
     # [END howto_operator_sagemaker_processing_async]
 
@@ -380,7 +380,7 @@ def check_dag_status(**kwargs: Any) -> None:
         task_id="train_model",
         aws_conn_id=SAGEMAKER_CONN_ID,
         print_log=False,
-        config=test_setup["training_config"],
+        config=test_setup["training_config"],  # type: ignore[index]
     )
     # [END howto_operator_sagemaker_training_async]
 
@@ -388,14 +388,14 @@ def check_dag_status(**kwargs: Any) -> None:
     test_model = SageMakerTransformOperatorAsync(
         task_id="test_model",
         aws_conn_id=SAGEMAKER_CONN_ID,
-        config=test_setup["transform_config"],
+        config=test_setup["transform_config"],  # type: ignore[index]
     )
     # [END howto_operator_sagemaker_transform_async]
 
     delete_model = SageMakerDeleteModelOperator(
         task_id="delete_model",
         aws_conn_id=SAGEMAKER_CONN_ID,
-        config={"ModelName": test_setup["model_name"]},
+        config={"ModelName": test_setup["model_name"]},  # type: ignore[index]
         trigger_rule=TriggerRule.ALL_DONE,
     )
 
@@ -403,7 +403,7 @@ def check_dag_status(**kwargs: Any) -> None:
         task_id="delete_bucket",
         aws_conn_id=SAGEMAKER_CONN_ID,
         trigger_rule=TriggerRule.ALL_DONE,
-        bucket_name=test_setup["bucket_name"],
+        bucket_name=test_setup["bucket_name"],  # type: ignore[index]
         force_delete=True,
     )
 
diff --git a/astronomer/providers/core/sensors/external_task.py b/astronomer/providers/core/sensors/external_task.py
index 574aa19ca..1cbb949a3 100644
--- a/astronomer/providers/core/sensors/external_task.py
+++ b/astronomer/providers/core/sensors/external_task.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import datetime
 import warnings
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any
 
 from airflow.sensors.external_task import ExternalTaskSensor
 from airflow.utils.session import provide_session
@@ -74,7 +76,7 @@ def execute(self, context: Context) -> None:
 
     @provide_session
     def execute_complete(  # type: ignore[override]
-        self, context: Context, session: "Session", event: Optional[Dict[str, Any]] = None
+        self, context: Context, session: Session, event: dict[str, Any] | None = None
     ) -> None:
         """Verifies that there is a success status for each task via execution date."""
         execution_dates = self.get_execution_dates(context)
@@ -87,10 +89,10 @@ def execute_complete(  # type: ignore[override]
             raise_error_or_skip_exception(self.soft_fail, error)
         return None
 
-    def get_execution_dates(self, context: Context) -> List[datetime.datetime]:
+    def get_execution_dates(self, context: Context) -> list[datetime.datetime]:
         """Helper function to set execution dates depending on which context and/or internal fields are populated."""
         if self.execution_delta:
-            execution_date = context["execution_date"] - self.execution_delta
+            execution_date: datetime.datetime = context["execution_date"] - self.execution_delta
         elif self.execution_date_fn:
             execution_date = self._handle_execution_date_fn(context=context)
         else:
@@ -139,7 +141,7 @@ def execute(self, context: Context) -> None:
             method_name="execute_complete",
         )
 
-    def execute_complete(self, context: "Context", event: Optional[Dict[str, Any]] = None) -> Any:
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> Any:
         """
         Callback for when the trigger fires - returns immediately.
         Return true and log the response if state is not success state raise ValueError
diff --git a/astronomer/providers/core/sensors/filesystem.py b/astronomer/providers/core/sensors/filesystem.py
index 39512f6e1..1ed91f13f 100644
--- a/astronomer/providers/core/sensors/filesystem.py
+++ b/astronomer/providers/core/sensors/filesystem.py
@@ -24,7 +24,7 @@ class FileSensorAsync(FileSensor):
         ``**`` in glob filepath parameter. Defaults to ``False``.
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         warnings.warn(
             (
                 "This module is deprecated and will be removed in airflow>=2.9.0"
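-- 
Reviewer note (below the signature marker, so git-am strips it): a
minimal sketch of the "type: ignore[index]" pattern used in
example_sagemaker.py, assuming Airflow 2.4+ TaskFlow; the DAG id, task
name, and key are hypothetical stand-ins, not code from this patch. At
DAG parse time, the value returned by calling a @task function is an
XComArg rather than the dict the function body returns, so mypy cannot
prove that subscripting it is valid; the override is therefore scoped
to the [index] error code instead of silencing the whole line.

    from __future__ import annotations

    import datetime
    from typing import Any

    from airflow import DAG
    from airflow.decorators import task


    @task
    def set_up() -> dict[str, Any]:
        # Hypothetical stand-in for the example DAG's set-up task.
        return {"bucket_name": "example-bucket"}


    with DAG(
        dag_id="type_ignore_index_sketch",
        start_date=datetime.datetime(2024, 1, 22),
        schedule=None,
    ):
        test_setup = set_up()
        # test_setup is an XComArg here, not a dict, so mypy flags the
        # subscript; the comment suppresses only the [index] error code.
        bucket_name = test_setup["bucket_name"]  # type: ignore[index]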
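The external_task.py changes rest on one fact worth spelling out: with
PEP 563's "from __future__ import annotations", annotations are no
longer evaluated at runtime, so the PEP 585/604 spellings (dict[...],
list[...], X | None) type-check even on interpreters that predate their
runtime support. A self-contained sketch; the function name mirrors the
patch, but the body is illustrative only:

    from __future__ import annotations

    import datetime
    from typing import Any


    def execute_complete(event: dict[str, Any] | None = None) -> list[datetime.datetime]:
        # dict[...], list[...] and "| None" replace the removed
        # typing.Dict, typing.List and typing.Optional imports.
        return []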