Skip to content

Commit

Permalink
Add directory transfer support for SFTPOperator (#44126)
Browse files Browse the repository at this point in the history
* Add directory put and get functions for sftp provider

* Add test code

* Add directory exists check

* Fix merge conflict

* Add path exists check
  • Loading branch information
Dawnpool authored Dec 28, 2024
1 parent 61412b3 commit 51fea3e
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 2 deletions.
46 changes: 46 additions & 0 deletions providers/src/airflow/providers/sftp/hooks/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import stat
from collections.abc import Sequence
from fnmatch import fnmatch
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable

import asyncssh
Expand Down Expand Up @@ -254,6 +255,51 @@ def delete_file(self, path: str) -> None:
conn = self.get_conn()
conn.remove(path)

def retrieve_directory(self, remote_full_path: str, local_full_path: str, prefetch: bool = True) -> None:
"""
Transfer the remote directory to a local location.
If local_full_path is a string path, the directory will be put
at that location.
:param remote_full_path: full path to the remote directory
:param local_full_path: full path to the local directory
:param prefetch: controls whether prefetch is performed (default: True)
"""
if Path(local_full_path).exists():
raise AirflowException(f"{local_full_path} already exists")
Path(local_full_path).mkdir(parents=True)
files, dirs, _ = self.get_tree_map(remote_full_path)
for dir_path in dirs:
new_local_path = os.path.join(local_full_path, os.path.relpath(dir_path, remote_full_path))
Path(new_local_path).mkdir(parents=True, exist_ok=True)
for file_path in files:
new_local_path = os.path.join(local_full_path, os.path.relpath(file_path, remote_full_path))
self.retrieve_file(file_path, new_local_path, prefetch)

def store_directory(self, remote_full_path: str, local_full_path: str, confirm: bool = True) -> None:
"""
Transfer a local directory to the remote location.
If local_full_path is a string path, the directory will be read
from that location.
:param remote_full_path: full path to the remote directory
:param local_full_path: full path to the local directory
"""
if self.path_exists(remote_full_path):
raise AirflowException(f"{remote_full_path} already exists")
self.create_directory(remote_full_path)
for root, dirs, files in os.walk(local_full_path):
for dir_name in dirs:
dir_path = os.path.join(root, dir_name)
new_remote_path = os.path.join(remote_full_path, os.path.relpath(dir_path, local_full_path))
self.create_directory(new_remote_path)
for file_name in files:
file_path = os.path.join(root, file_name)
new_remote_path = os.path.join(remote_full_path, os.path.relpath(file_path, local_full_path))
self.store_file(new_remote_path, file_path, confirm)

def get_mod_time(self, path: str) -> str:
"""
Get an entry's modification time.
Expand Down
12 changes: 10 additions & 2 deletions providers/src/airflow/providers/sftp/operators/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,22 @@ def execute(self, context: Any) -> str | list[str] | None:
Path(local_folder).mkdir(parents=True, exist_ok=True)
file_msg = f"from {_remote_filepath} to {_local_filepath}"
self.log.info("Starting to transfer %s", file_msg)
self.sftp_hook.retrieve_file(_remote_filepath, _local_filepath)
if self.sftp_hook.isdir(_remote_filepath):
self.sftp_hook.retrieve_directory(_remote_filepath, _local_filepath)
else:
self.sftp_hook.retrieve_file(_remote_filepath, _local_filepath)
else:
remote_folder = os.path.dirname(_remote_filepath)
if self.create_intermediate_dirs:
self.sftp_hook.create_directory(remote_folder)
file_msg = f"from {_local_filepath} to {_remote_filepath}"
self.log.info("Starting to transfer file %s", file_msg)
self.sftp_hook.store_file(_remote_filepath, _local_filepath, confirm=self.confirm)
if os.path.isdir(_local_filepath):
self.sftp_hook.store_directory(
_remote_filepath, _local_filepath, confirm=self.confirm
)
else:
self.sftp_hook.store_file(_remote_filepath, _local_filepath, confirm=self.confirm)

except Exception as e:
raise AirflowException(f"Error while transferring {file_msg}, error: {e}")
Expand Down
17 changes: 17 additions & 0 deletions providers/tests/sftp/hooks/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,23 @@ def test_get_matched_files_with_different_pattern(self):
output = self.hook.get_files_by_pattern(self.temp_dir, "*_file_*.txt")
assert output == [ANOTHER_FILE_FOR_TESTS]

def test_store_and_retrieve_directory(self):
stored_dir_name = "stored_dir"
self.hook.store_directory(
remote_full_path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, stored_dir_name),
local_full_path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, SUB_DIR),
)
output = self.hook.list_directory(
path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, stored_dir_name)
)
assert output == [TMP_FILE_FOR_TESTS]
retrieved_dir_name = "retrieved_dir"
self.hook.retrieve_directory(
remote_full_path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, stored_dir_name),
local_full_path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS, retrieved_dir_name),
)
assert retrieved_dir_name in os.listdir(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS))


class MockSFTPClient:
def __init__(self):
Expand Down
30 changes: 30 additions & 0 deletions providers/tests/sftp/operators/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,21 @@ def test_multiple_paths_get(self, mock_get):
assert args0 == (remote_filepath[0], local_filepath[0])
assert args1 == (remote_filepath[1], local_filepath[1])

@mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.retrieve_directory")
def test_str_dirpaths_get(self, mock_get):
local_dirpath = "/tmp_local"
remote_dirpath = "/tmp"
SFTPOperator(
task_id="test_str_to_list",
sftp_hook=self.sftp_hook,
local_filepath=local_dirpath,
remote_filepath=remote_dirpath,
operation=SFTPOperation.GET,
).execute(None)
assert mock_get.call_count == 1
args, _ = mock_get.call_args_list[0]
assert args == (remote_dirpath, local_dirpath)

@mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.store_file")
def test_str_filepaths_put(self, mock_get):
local_filepath = "/tmp/test"
Expand Down Expand Up @@ -443,6 +458,21 @@ def test_multiple_paths_put(self, mock_put):
assert args0 == (remote_filepath[0], local_filepath[0])
assert args1 == (remote_filepath[1], local_filepath[1])

@mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.store_directory")
def test_str_dirpaths_put(self, mock_get):
local_dirpath = "/tmp"
remote_dirpath = "/tmp_remote"
SFTPOperator(
task_id="test_str_dirpaths_put",
sftp_hook=self.sftp_hook,
local_filepath=local_dirpath,
remote_filepath=remote_dirpath,
operation=SFTPOperation.PUT,
).execute(None)
assert mock_get.call_count == 1
args, _ = mock_get.call_args_list[0]
assert args == (remote_dirpath, local_dirpath)

@mock.patch("airflow.providers.sftp.operators.sftp.SFTPHook.retrieve_file")
def test_return_str_when_local_filepath_was_str(self, mock_get):
local_filepath = "/tmp/ltest1"
Expand Down

0 comments on commit 51fea3e

Please sign in to comment.