diff --git a/s3torchconnector/pyproject.toml b/s3torchconnector/pyproject.toml
index 1f96c0c0..f642e6f7 100644
--- a/s3torchconnector/pyproject.toml
+++ b/s3torchconnector/pyproject.toml
@@ -24,6 +24,7 @@ classifiers = [
 dependencies = [
     "torch >= 2.0.1, != 2.5.0",
     "s3torchconnectorclient >= 1.3.0",
+    "pathlib_abc >= 0.3.1"
 ]
 
 [project.optional-dependencies]
diff --git a/s3torchconnector/src/s3torchconnector/__init__.py b/s3torchconnector/src/s3torchconnector/__init__.py
index d46ecb02..da20f304 100644
--- a/s3torchconnector/src/s3torchconnector/__init__.py
+++ b/s3torchconnector/src/s3torchconnector/__init__.py
@@ -10,6 +10,7 @@
 from .s3iterable_dataset import S3IterableDataset
 from .s3map_dataset import S3MapDataset
 from .s3checkpoint import S3Checkpoint
+from .s3path import S3Path
 from ._version import __version__
 from ._s3client import S3ClientConfig
 
@@ -21,5 +22,6 @@
     "S3Writer",
     "S3Exception",
     "S3ClientConfig",
+    "S3Path",
     "__version__",
 ]
diff --git a/s3torchconnector/src/s3torchconnector/s3path.py b/s3torchconnector/src/s3torchconnector/s3path.py
new file mode 100644
--- /dev/null
+++ b/s3torchconnector/src/s3torchconnector/s3path.py
@@ -0,0 +1,439 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# // SPDX-License-Identifier: BSD
+import errno
+import io
+import logging
+import os
+import posixpath
+import stat
+import time
+from types import SimpleNamespace
+from typing import Optional
+
+from pathlib import PurePosixPath
+from pathlib_abc import ParserBase, PathBase, UnsupportedOperation
+
+from s3torchconnectorclient._mountpoint_s3_client import S3Exception
+from ._s3client import S3Client, S3ClientConfig
+
+logger = logging.getLogger(__name__)
+
+ENV_S3_TORCH_CONNECTOR_REGION = "S3_TORCH_CONNECTOR_REGION"
+ENV_S3_TORCH_CONNECTOR_THROUGHPUT_TARGET_GBPS = (
+    "S3_TORCH_CONNECTOR_THROUGHPUT_TARGET_GBPS"
+)
+# Misspelled ("GPBS") variant of the name above; still read as a fallback.
+ENV_S3_TORCH_CONNECTOR_THROUGHPUT_TARGET_GPBS = (
+    "S3_TORCH_CONNECTOR_THROUGHPUT_TARGET_GPBS"
+)
+ENV_S3_TORCH_CONNECTOR_PART_SIZE_MB = "S3_TORCH_CONNECTOR_PART_SIZE_MB"
+DRIVE = "s3://"
+
+
+def _get_default_bucket_region():
+    """Return the first region found in the environment, or None."""
+    for var in (
+        ENV_S3_TORCH_CONNECTOR_REGION,
+        "AWS_DEFAULT_REGION",
+        "AWS_REGION",
+        "REGION",
+    ):
+        if var in os.environ:
+            return os.environ[var]
+    return None
+
+
+def _get_default_throughput_target_gbps():
+    """Return the throughput target from the environment as a float, or None."""
+    for var in (
+        ENV_S3_TORCH_CONNECTOR_THROUGHPUT_TARGET_GBPS,
+        ENV_S3_TORCH_CONNECTOR_THROUGHPUT_TARGET_GPBS,
+    ):
+        if var in os.environ:
+            return float(os.environ[var])
+    return None
+
+
+def _get_default_part_size():
+    """Return the part size in bytes (configured in MiB), or None if unset."""
+    if ENV_S3_TORCH_CONNECTOR_PART_SIZE_MB in os.environ:
+        return int(os.environ[ENV_S3_TORCH_CONNECTOR_PART_SIZE_MB]) * 1024 * 1024
+    return None
+
+
+def _split_uri(path):
+    """Split ``scheme://bucket/prefix`` into ``(scheme, bucket, prefix)``.
+
+    ``?`` and ``#`` are legal in S3 keys, so unlike ``urllib.parse.urlparse``
+    they stay in the prefix instead of being split off as query/fragment.
+    A string without a valid ``scheme://`` part yields ``("", "", path)``.
+    """
+    scheme, sep, rest = path.partition("://")
+    is_uri = (
+        bool(sep)
+        and scheme.isascii()
+        and scheme[:1].isalpha()
+        and all(c.isalnum() or c in "+-." for c in scheme)
+    )
+    if not is_uri:
+        return "", "", path
+    bucket, _, prefix = rest.partition("/")
+    return scheme.lower(), bucket, prefix
+
+
+class S3Parser(ParserBase):
+    """Path parser for ``s3://bucket/key`` URIs with ``/`` as the separator."""
+
+    @classmethod
+    def _unsupported_msg(cls, attribute):
+        return f"{cls.__name__}.{attribute} is unsupported"
+
+    @property
+    def sep(self):
+        return "/"
+
+    def join(self, path, *paths):
+        return posixpath.join(path, *paths)
+
+    def split(self, path):
+        """Split *path* into ``(parent, name)``.
+
+        The parent of an S3 URI keeps its ``scheme://bucket/`` anchor; for any
+        other string the parent is reported as ``""``.
+        """
+        scheme, bucket, prefix = _split_uri(path)
+        parent, _, name = prefix.lstrip("/").rpartition("/")
+        if not bucket:
+            return bucket, name
+        return (scheme + "://" + bucket + "/" + parent, name)
+
+    def splitdrive(self, path):
+        """Split *path* into ``("scheme://bucket", key)``."""
+        scheme, bucket, prefix = _split_uri(path)
+        return f"{scheme}://{bucket}", prefix.lstrip("/")
+
+    def splitext(self, path):
+        return posixpath.splitext(path)
+
+    def normcase(self, path):
+        return posixpath.normcase(path)
+
+    def isabs(self, path):
+        """Return True if *path* contains a ``://`` scheme separator."""
+        return "://" in os.fspath(path)
+
+
+class S3Path(PathBase):
+    """``pathlib``-style path for an S3 object or key prefix.
+
+    ``stat()`` results are cached per ``(bucket, key)`` for a short TTL. The
+    cache is shared by all instances; a path's entry is dropped whenever that
+    path is changed through ``open("w")``, ``mkdir``, ``unlink`` or ``rmdir``.
+    """
+
+    __slots__ = ("_region", "_s3_client_config", "_client", "_raw_path")
+    parser = S3Parser()
+    _stat_cache_ttl_seconds = 1
+    _stat_cache_size = 1024
+    _stat_cache = {}
+
+    def __init__(
+        self,
+        *pathsegments,
+        client: Optional[S3Client] = None,
+        region=None,
+        s3_client_config=None,
+    ):
+        """Create a path from an ``s3://bucket[/key]`` URI.
+
+        Region and client configuration fall back to environment variables,
+        and a new ``S3Client`` is created unless *client* is given.
+
+        Raises:
+            ValueError: if the path is not an ``s3://`` URI.
+        """
+        super().__init__(*pathsegments)
+        if not self.drive.startswith(DRIVE):
+            raise ValueError("Should pass in S3 uri")
+        self._region = region or _get_default_bucket_region()
+        self._s3_client_config = s3_client_config or S3ClientConfig(
+            throughput_target_gbps=_get_default_throughput_target_gbps(),
+            part_size=_get_default_part_size(),
+        )
+        self._client = client or S3Client(
+            region=self._region,
+            s3client_config=self._s3_client_config,
+        )
+
+    def __repr__(self):
+        return f"{type(self).__name__}({str(self)!r})"
+
+    def __hash__(self):
+        return hash(str(self))
+
+    def __eq__(self, other):
+        if not isinstance(other, S3Path):
+            return NotImplemented
+        return str(self) == str(other)
+
+    def with_segments(self, *pathsegments):
+        """Build a path from *pathsegments* that reuses this path's client.
+
+        Segments are joined with ``/``. A result that is not already an
+        ``s3://`` URI is treated as a key and anchored to this path's bucket.
+        """
+        path = "/".join(str(segment) for segment in pathsegments).lstrip("/")
+        if not path.startswith((self.anchor, DRIVE)):
+            path = f"{self.anchor}{path}"
+        return type(self)(
+            path,
+            client=self._client,
+            region=self._region,
+            s3_client_config=self._s3_client_config,
+        )
+
+    @property
+    def bucket(self):
+        """Bucket name, or ``""`` if the path is not an absolute S3 URI."""
+        if self.is_absolute() and self.drive.startswith(DRIVE):
+            return self.drive[len(DRIVE) :]
+        return ""
+
+    @property
+    def key(self):
+        """Object key (the parts after the bucket joined by ``/``), or ``""``."""
+        if self.is_absolute() and len(self.parts) > 1:
+            return self.parser.sep.join(self.parts[1:])
+        return ""
+
+    def open(self, mode="r", buffering=-1, encoding=None, errors=None, newline=None):
+        """Open the object for reading (``r``) or writing (``w``).
+
+        Without ``b`` in *mode* the stream is wrapped in ``io.TextIOWrapper``.
+
+        Raises:
+            ValueError: for non-default *buffering* or a non-absolute path.
+            FileNotFoundError: if the client raises ``S3Exception`` when
+                opening for reading.
+            UnsupportedOperation: for any mode other than read or write.
+        """
+        if buffering != -1:
+            raise ValueError("Only default buffering (-1) is supported.")
+        if not self.is_absolute():
+            raise ValueError("S3Path must be absolute.")
+        action = "".join(c for c in mode if c not in "btU")
+        if action == "r":
+            try:
+                fileobj = self._client.get_object(self.bucket, self.key)
+            except S3Exception as e:
+                raise FileNotFoundError(errno.ENOENT, "Not found", str(self)) from e
+        elif action == "w":
+            # Whatever is cached for this key is about to be replaced. It can
+            # be re-cached before the writer is closed, but only for the TTL.
+            self._invalidate_stat_cache()
+            fileobj = self._client.put_object(self.bucket, self.key)
+        else:
+            raise UnsupportedOperation(f"Unsupported mode: {mode!r}")
+        if "b" not in mode:
+            fileobj = io.TextIOWrapper(fileobj, encoding, errors, newline)
+        return fileobj
+
+    def stat(self, *, follow_symlinks=True):
+        """Return an ``os.stat_result`` for this object or prefix.
+
+        An existing object is reported as a regular file. Otherwise, if
+        anything is listed under ``key + "/"``, the path is reported as a
+        directory of size 0. Only ``st_mode``, ``st_ino``, ``st_dev``,
+        ``st_size`` and ``st_mtime`` are populated; *follow_symlinks* is
+        ignored. Results are cached for ``_stat_cache_ttl_seconds``.
+
+        Raises:
+            FileNotFoundError: if neither an object nor a prefix is found.
+        """
+        key = self.key.rstrip("/")
+        cache_key = (self.bucket, key)
+        cached = self._stat_cache.get(cache_key)
+        if cached is not None:
+            result, timestamp = cached
+            if time.time() - timestamp < self._stat_cache_ttl_seconds:
+                return result
+            self._stat_cache.pop(cache_key, None)
+        try:
+            info = self._client.head_object(self.bucket, key)
+            mode = stat.S_IFREG
+        except S3Exception as e:
+            # Not an object; a prefix with any content counts as a directory.
+            # next() gets a default so that an empty listing cannot leak
+            # StopIteration into generators such as iterdir().
+            page = next(self._list_objects(max_keys=2), None)
+            if page is None or not (
+                len(page.object_info) > 0 or len(page.common_prefixes) > 0
+            ):
+                error_msg = f"No stats available for {self}; it may not exist."
+                raise FileNotFoundError(error_msg) from e
+            info = SimpleNamespace(size=0, last_modified=None)
+            mode = stat.S_IFDIR
+
+        result = os.stat_result(
+            (
+                mode,  # mode
+                "/".join(cache_key),  # ino: tells objects apart when stats are compared
+                DRIVE,  # dev
+                None,  # nlink
+                None,  # uid
+                None,  # gid
+                info.size,  # size
+                None,  # atime
+                info.last_modified or 0,  # mtime
+                None,  # ctime
+            )
+        )
+        if len(self._stat_cache) >= self._stat_cache_size:
+            # Evict the oldest entry; dicts iterate in insertion order.
+            self._stat_cache.pop(next(iter(self._stat_cache)))
+
+        self._stat_cache[cache_key] = (result, time.time())
+        return result
+
+    def _invalidate_stat_cache(self):
+        """Drop the cached stat result for this path, if there is one."""
+        self._stat_cache.pop((self.bucket, self.key.rstrip("/")), None)
+
+    def iterdir(self):
+        """Yield the children of this prefix, sub-prefixes first in each page.
+
+        Raises:
+            NotADirectoryError: when iteration starts, if this path is not a
+                directory.
+        """
+        if not self.is_dir():
+            raise NotADirectoryError("not a s3 folder")
+        key = "" if not self.key else self.key.rstrip("/") + "/"
+        for page in self._list_objects():
+            for prefix in page.common_prefixes:
+                # yield directories first
+                yield self.with_segments(prefix.rstrip("/"))
+            for info in page.object_info:
+                # skip the marker object that stands for this folder itself
+                if info.key != key:
+                    yield self.with_segments(info.key)
+
+    def mkdir(self, mode=0o777, parents=False, exist_ok=False):
+        """Create the folder marker object ``key + "/"`` for this path.
+
+        *mode* and *parents* are accepted for API compatibility and ignored.
+
+        Raises:
+            FileExistsError: if the folder exists and *exist_ok* is false.
+        """
+        if self.is_dir():
+            if exist_ok:
+                return
+            raise FileExistsError(f"S3 folder {self} already exists.")
+        with self._client.put_object(self.bucket, self.key.rstrip("/") + "/"):
+            pass
+        self._invalidate_stat_cache()
+
+    def unlink(self, missing_ok=False):
+        """Delete the object at this path.
+
+        No existence check is made before deleting, so *missing_ok* has no
+        effect here.
+
+        Raises:
+            IsADirectoryError: if the path is a folder; use ``rmdir`` instead.
+        """
+        if self.is_dir():
+            raise IsADirectoryError(
+                f"Path {self} is a directory; call rmdir instead of unlink."
+            )
+        self._client.delete_object(self.bucket, self.key)
+        self._invalidate_stat_cache()
+
+    def rmdir(self):
+        """Delete this folder's marker object; the folder must be empty.
+
+        Raises:
+            NotADirectoryError: if the path is not a folder.
+            OSError: with ``errno.ENOTEMPTY`` if the folder holds anything
+                other than its own marker object.
+        """
+        if not self.is_dir():
+            raise NotADirectoryError(f"{self} is not an s3 folder")
+        marker = self.key.rstrip("/") + "/"
+        # Two entries are enough: assuming keys are listed in lexicographic
+        # order (as S3 does), the marker sorts first under its own prefix.
+        page = next(self._list_objects(max_keys=2), None)
+        if page is not None and (
+            len(page.common_prefixes) > 0
+            or any(info.key != marker for info in page.object_info)
+        ):
+            raise OSError(errno.ENOTEMPTY, f"{self} is not empty")
+        self._client.delete_object(self.bucket, marker)
+        self._invalidate_stat_cache()
+
+    def glob(self, pattern, *, case_sensitive=None, recurse_symlinks=True):
+        """Select the paths below this one that match the relative *pattern*.
+
+        Raises:
+            NotImplementedError: if *pattern* has a ``..`` component or is not
+                relative.
+        """
+        parts = list(PurePosixPath(pattern).parts)
+        # Reject real ".." components only, not names that merely contain "..".
+        if ".." in parts:
+            raise NotImplementedError(
+                "Relative paths with '..' not supported in glob patterns"
+            )
+        if pattern.startswith(self.anchor) or pattern.startswith("/"):
+            raise NotImplementedError("Non-relative patterns are unsupported")
+
+        select = self._glob_selector(parts, case_sensitive, recurse_symlinks)
+        return select(self)
+
+    def with_name(self, name):
+        """Return a new path with the file name changed.
+
+        Raises:
+            ValueError: if *name* is empty or contains a path separator.
+        """
+        # parser.split() reports an empty parent for anything that is not an
+        # s3:// URI, so it cannot detect a separator inside a bare name.
+        if not name or self.parser.sep in name:
+            raise ValueError(f"Invalid name {name!r}")
+        return self.with_segments(str(self.parent), name)
+
+    def _list_objects(self, max_keys: int = 1000):
+        """Yield listing pages for this path's key prefix, delimited by ``/``.
+
+        Raises:
+            RuntimeError: if the client raises ``S3Exception`` while listing.
+        """
+        key = "" if not self.key else self.key.rstrip("/") + "/"
+        try:
+            yield from self._client.list_objects(
+                self.bucket, key, delimiter="/", max_keys=max_keys
+            )
+        except S3Exception as e:
+            raise RuntimeError(f"Failed to list contents of {self}") from e
+
+    def __getstate__(self):
+        """Return the picklable state: every slot except the S3 client."""
+        state = {
+            slot: getattr(self, slot, None)
+            for cls in self.__class__.__mro__
+            for slot in getattr(cls, "__slots__", [])
+            if slot != "_client"
+        }
+        return (None, state)
+
+    def __setstate__(self, state):
+        """Restore the slot values and create a fresh ``S3Client`` from them."""
+        _, state_dict = state
+        for slot, value in state_dict.items():
+            if slot != "_client":
+                setattr(self, slot, value)
+        self._client = S3Client(
+            region=self._region,
+            s3client_config=self._s3_client_config,
+        )
diff --git a/s3torchconnector/tst/unit/test_s3path.py b/s3torchconnector/tst/unit/test_s3path.py
new file mode 100644
--- /dev/null
+++ b/s3torchconnector/tst/unit/test_s3path.py
@@ -0,0 +1,269 @@
+import collections.abc
+import io
+import pytest
+
+from pathlib_abc import PathBase
+
+from s3torchconnector import S3Path
+from s3torchconnector._s3client._s3client import S3Client
+from s3torchconnector._s3client._mock_s3client import MockS3Client
+
+
+def s3_uri(bucket, key=None):
+    return f"s3://{bucket}" if key is None else f"s3://{bucket}/{key}"
+
+
+TEST_BUCKET = "test-bucket"
+TEST_KEY = "test-key"
+TEST_REGION = "us-east-1"
+TEST_S3_URI = s3_uri(TEST_BUCKET, TEST_KEY)
+MISSING_KEY = "foo"
+MISSING_S3_URI = s3_uri(TEST_BUCKET, MISSING_KEY)
+
+
+@pytest.fixture(autouse=True)
+def clear_stat_cache():
+    # The stat cache is shared by every S3Path; keep tests independent.
+    S3Path._stat_cache.clear()
+
+
+@pytest.fixture
+def s3_client() -> S3Client:
+    client = MockS3Client(TEST_REGION, TEST_BUCKET)
+    return client
+
+
+@pytest.fixture
+def s3_bucket_path(s3_client) -> S3Path:
+    return S3Path(s3_uri(TEST_BUCKET), client=s3_client)
+
+
+@pytest.fixture
+def s3_path(s3_bucket_path) -> S3Path:
+    s3_path = s3_bucket_path / TEST_KEY
+    s3_path._client.add_object(TEST_KEY, b"this is an s3 file\n")
+    return s3_path
+
+
+@pytest.fixture
+def missing_s3_path(s3_bucket_path) -> S3Path:
+    return s3_bucket_path / MISSING_KEY
+
+
+def test_s3path_subclass_path(s3_path: S3Path):
+    assert issubclass(S3Path, PathBase)
+    assert isinstance(s3_path, PathBase)
+
+
+def test_s3path_creation(s3_path: S3Path):
+    assert s3_path
+    assert s3_path.bucket == TEST_BUCKET
+    assert s3_path.key == TEST_KEY
+
+
+def test_s3path_key_keeps_url_reserved_characters(s3_bucket_path):
+    path = s3_bucket_path / "dir" / "file#1?.txt"
+    assert path.bucket == TEST_BUCKET
+    assert path.key == "dir/file#1?.txt"
+
+
+@pytest.mark.parametrize(
+    "path",
+    [(""), (TEST_KEY)],
+)
+def test_s3path_invalid_creation(path):
+    with pytest.raises(ValueError, match="Should pass in S3 uri"):
+        S3Path(path)
+
+
+def test_s3path_samefile(s3_path, missing_s3_path):
+    assert s3_path.samefile(TEST_S3_URI)
+
+    s3_path._client.add_object("other-key", b"another s3 file\n")
+    assert not s3_path.samefile(s3_uri(TEST_BUCKET, "other-key"))
+
+    with pytest.raises(FileNotFoundError):
+        s3_path.samefile(MISSING_S3_URI)
+    with pytest.raises(FileNotFoundError):
+        s3_path.samefile(missing_s3_path)
+    with pytest.raises(FileNotFoundError):
+        missing_s3_path.samefile(TEST_S3_URI)
+    with pytest.raises(FileNotFoundError):
+        missing_s3_path.samefile(s3_path)
+
+
+def test_s3path_exists(s3_path, s3_bucket_path, missing_s3_path):
+    assert s3_path.exists() is True
+    assert s3_bucket_path.exists() is True
+    assert missing_s3_path.exists() is False
+
+
+def test_s3path_open(s3_path):
+    with s3_path.open("r") as reader:
+        assert isinstance(reader, io.TextIOBase)
+        assert reader.read() == "this is an s3 file\n"
+    with s3_path.open("rb") as reader:
+        assert isinstance(reader, io.BufferedIOBase)
+        assert reader.read().strip() == b"this is an s3 file"
+
+
+def test_s3path_read_write_bytes(s3_path):
+    (s3_path / "fileA").write_bytes(b"abcd")
+    assert (s3_path / "fileA").read_bytes() == b"abcd"
+
+    with pytest.raises(TypeError):
+        (s3_path / "fileA").write_bytes("somestr")
+    assert (s3_path / "fileA").read_bytes() == b"abcd"
+
+
+def test_s3path_read_write_text(s3_path):
+    (s3_path / "fileA").write_text("äbcd", encoding="latin-1")
+    assert (s3_path / "fileA").read_text(encoding="utf-8", errors="ignore") == "bcd"
+
+    with pytest.raises(TypeError):
+        (s3_path / "fileA").write_text(b"somebytes")
+    assert (s3_path / "fileA").read_text(encoding="latin-1") == "äbcd"
+
+
+def test_s3path_iterdir(s3_path, s3_bucket_path):
+    s3_bucket_path._client.add_object("file1.txt", b"file 1 content")
+    s3_bucket_path._client.add_object("file2.txt", b"file 2 content")
+    s3_bucket_path._client.add_object("dir1/file3.txt", b"nested file")
+    (s3_bucket_path / "dir1" / "nested_dir").mkdir()
+    (s3_bucket_path / "dir2").mkdir()
+
+    bucket_contents = list(s3_bucket_path.iterdir())
+    assert bucket_contents == [
+        s3_bucket_path / "dir1",
+        s3_bucket_path / "dir2",
+        s3_bucket_path / "file1.txt",
+        s3_bucket_path / "file2.txt",
+        s3_path,
+    ]
+
+    dir1_path = s3_bucket_path / "dir1"
+    dir1_contents = list(dir1_path.iterdir())
+    assert dir1_contents == [dir1_path / "nested_dir", dir1_path / "file3.txt"]
+
+
+def test_s3path_nondir(s3_path):
+    with pytest.raises(NotADirectoryError):
+        # does not follow pathlib in python 3.13+, which raises immediately before iterating
+        next(s3_path.iterdir())
+
+
+def test_s3path_glob(s3_path, s3_bucket_path):
+    it = s3_bucket_path.glob(TEST_KEY)
+    assert isinstance(it, collections.abc.Iterator)
+    assert set(it) == {s3_path}
+
+
+def test_s3path_glob_empty_pattern(s3_path):
+    # no relative paths in s3
+    assert list(s3_path.glob("")) == [s3_path]
+    assert list(s3_path.glob(".")) == [s3_path]
+    assert list(s3_path.glob("./")) == [s3_path]
+
+
+def test_s3path_glob_rejects_parent_components(s3_bucket_path):
+    with pytest.raises(NotImplementedError):
+        s3_bucket_path.glob("../" + TEST_KEY)
+
+
+def test_s3path_stat(s3_path, s3_bucket_path):
+    stat_file = s3_path.stat()
+    stat_folder = s3_bucket_path.stat()
+
+    # no concept of directories in s3, existing folders count as "directories"
+    assert isinstance(stat_file.st_mode, int)
+    assert stat_file.st_mode != stat_folder.st_mode
+
+    assert stat_file.st_dev == "s3://"
+    assert stat_file.st_dev == stat_folder.st_dev
+
+
+def test_s3path_isdir(s3_path, s3_bucket_path, missing_s3_path):
+    assert not s3_path.is_dir()
+    assert s3_bucket_path.is_dir()
+    # even though directories don't matter in s3, count missing prefixes as non directories
+    assert not missing_s3_path.is_dir()
+
+
+def test_s3path_withname(s3_path):
+    new_name = "new_file.txt"
+    new_path = s3_path.with_name(new_name)
+    assert new_path.key == new_name, f"Expected {new_name}, got {new_path.key}"
+
+    s3_path_with_slash = s3_path.with_segments("folder", "old_file.txt")
+    new_path_with_slash = s3_path_with_slash.with_name(new_name)
+    assert (
+        new_path_with_slash.key == "folder/new_file.txt"
+    ), f"Expected folder/new_file.txt, got {new_path_with_slash.key}"
+
+    with pytest.raises(ValueError, match="Invalid name 'invalid/name.txt'"):
+        s3_path.with_name("invalid/name.txt")
+
+
+def test_s3path_rmdir(s3_path):
+    empty_folder = s3_path / "empty"
+    empty_folder.mkdir(parents=True, exist_ok=True)
+    empty_folder.rmdir()
+    with pytest.raises(NotADirectoryError, match=f"{empty_folder} is not an s3 folder"):
+        empty_folder.rmdir()
+
+    nonempty_folder = s3_path / "nonempty"
+    nonempty_folder.mkdir(parents=True, exist_ok=True)
+    nonempty_folder._client.add_object("test-key/nonempty/file.txt", b"file")
+    with pytest.raises(OSError, match=f"{nonempty_folder} is not empty"):
+        nonempty_folder.rmdir()
+
+    parent_folder = s3_path / "parent"
+    (parent_folder / "child").mkdir()
+    with pytest.raises(OSError, match=f"{parent_folder} is not empty"):
+        parent_folder.rmdir()
+
+    nonexistent_folder = s3_path / "nonexistent_folder"
+    with pytest.raises(
+        NotADirectoryError, match=f"{nonexistent_folder} is not an s3 folder"
+    ):
+        nonexistent_folder.rmdir()
+
+
+def test_s3path_unlink(s3_path):
+    file = s3_path / "test_file.txt"
+    s3_path._client.add_object("test-key/test_file.txt", b"")
+    assert file.exists()
+    file.unlink()
+    assert not file.exists()
+
+    directory = s3_path / "some_directory"
+    directory.mkdir(parents=True, exist_ok=True)
+    with pytest.raises(IsADirectoryError):
+        directory.unlink()
+    with pytest.raises(IsADirectoryError):
+        directory.unlink(missing_ok=True)
+
+    nonexistent_file = s3_path / "nonexistent_file.txt"
+    nonexistent_file.unlink()  # no op
+    assert not nonexistent_file.exists()
+
+
+def test_s3path_mkdir(s3_path):
+    test_dir = s3_path / "test_dir"
+    test_dir.mkdir(parents=True, exist_ok=False)
+
+    assert test_dir.exists()
+    assert test_dir.is_dir()
+
+    with pytest.raises(FileExistsError, match=f"{test_dir} already exists"):
+        test_dir.mkdir(parents=True, exist_ok=False)
+
+    test_dir.mkdir(parents=True, exist_ok=True)
+    assert test_dir.exists()
+
+    parent_dir = s3_path / "parent_dir"
+    sub_dir = parent_dir / "sub_dir"
+    sub_dir.mkdir(parents=True, exist_ok=False)
+
+    assert parent_dir.exists()
+    assert sub_dir.exists()