diff --git a/transforms/code/repo_level_ordering/kfp_ray/repo_level_order_wf.py b/transforms/code/repo_level_ordering/kfp_ray/repo_level_order_wf.py index 42312ef3b..6c14abfd6 100644 --- a/transforms/code/repo_level_ordering/kfp_ray/repo_level_order_wf.py +++ b/transforms/code/repo_level_ordering/kfp_ray/repo_level_order_wf.py @@ -122,7 +122,14 @@ def repo_level_order( # Ray cluster ray_name: str = "repo_level_order-kfp-ray", ray_head_options: dict = {"cpu": 1, "memory": 4, "image": task_image}, - ray_worker_options: dict = {"replicas": 2, "max_replicas": 2, "min_replicas": 2, "cpu": 2, "memory": 4, "image": task_image}, + ray_worker_options: dict = { + "replicas": 2, + "max_replicas": 2, + "min_replicas": 2, + "cpu": 2, + "memory": 4, + "image": task_image, + }, server_url: str = "http://kuberay-apiserver-service.kuberay.svc.cluster.local:8888", # data access data_s3_config: str = "{'input_folder': 'test/repo_level_ordering/input', 'output_folder': 'test/repo_level_ordering/output'}", @@ -130,15 +137,15 @@ def repo_level_order( data_max_files: int = -1, data_num_samples: int = -1, # orchestrator - runtime_actor_options: dict = {'num_cpus': 0.8}, + runtime_actor_options: dict = {"num_cpus": 0.8}, runtime_pipeline_id: str = "pipeline_id", - runtime_code_location: dict = {'github': 'github', 'commit_hash': '12345', 'path': 'path'}, + runtime_code_location: dict = {"github": "github", "commit_hash": "12345", "path": "path"}, # repo_level_order parameters repo_lvl_stage_one_only: bool = False, repo_lvl_grouping_column: str = "repo_name", repo_lvl_store_type: str = "ray", repo_lvl_store_backend_dir: str = "", - repo_lvl_store_ray_cpus: float = "0.5", + repo_lvl_store_ray_cpus: float = 0.5, repo_lvl_store_ray_nworkers: int = 1, repo_lvl_sorting_enabled: bool = False, repo_lvl_sorting_algo: str = "SORT_BY_PATH", @@ -193,7 +200,9 @@ def repo_level_order( :return: None """ # create clean_up task - clean_up_task = cleanup_ray_op(ray_name=ray_name, run_id=run_id, server_url=server_url, additional_params=additional_params) + clean_up_task = cleanup_ray_op( + ray_name=ray_name, run_id=run_id, server_url=server_url, additional_params=additional_params + ) ComponentUtils.add_settings_to_component(clean_up_task, ONE_HOUR_SEC * 2) # pipeline definition with dsl.ExitHandler(clean_up_task):