
Commit dd79133

Author: Ryan Li
Message: add s3path
1 parent 2d95165 · commit dd79133

File tree: 4 files changed, +471 −0 lines


s3torchconnector/pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@ classifiers = [
 dependencies = [
     "torch >= 2.0.1, != 2.5.0",
     "s3torchconnectorclient >= 1.3.0",
+    "pathlib_abc >= 0.3.1"
 ]
 
 [project.optional-dependencies]
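pathlib_abc is the standalone PyPI distribution of pathlib's path abstract base classes; it supplies the ParserBase and PathBase hooks that the new s3path module below subclasses. An illustrative runtime check of the new version floor (not part of the commit; assumes the packaging library is installed):

from importlib.metadata import version
from packaging.version import Version

# Confirm the installed pathlib_abc satisfies the ">= 0.3.1" floor declared above.
assert Version(version("pathlib_abc")) >= Version("0.3.1")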

s3torchconnector/src/s3torchconnector/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -10,6 +10,7 @@
 from .s3iterable_dataset import S3IterableDataset
 from .s3map_dataset import S3MapDataset
 from .s3checkpoint import S3Checkpoint
+from .s3path import S3Path
 from ._version import __version__
 from ._s3client import S3ClientConfig

@@ -21,5 +22,6 @@
     "S3Writer",
     "S3Exception",
     "S3ClientConfig",
+    "S3Path",
     "__version__",
 ]
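With these two lines, S3Path becomes importable from the package root. A minimal usage sketch; the bucket, key, and region are illustrative placeholders, not part of the commit:

from s3torchconnector import S3Path

path = S3Path("s3://my-bucket/datasets/train", region="us-east-1")
print(path.bucket)  # "my-bucket"
print(path.key)     # "datasets/train"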
s3torchconnector/src/s3torchconnector/s3path.py

Lines changed: 313 additions & 0 deletions
@@ -0,0 +1,313 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: BSD
import errno
import io
import logging
import os
import posixpath
import stat
import time
from types import SimpleNamespace
from typing import Optional

from pathlib import PurePosixPath
from pathlib_abc import ParserBase, PathBase, UnsupportedOperation
from urllib.parse import urlparse

from s3torchconnectorclient._mountpoint_s3_client import S3Exception
from ._s3client import S3Client, S3ClientConfig

logger = logging.getLogger(__name__)

ENV_S3_TORCH_CONNECTOR_REGION = "S3_TORCH_CONNECTOR_REGION"
ENV_S3_TORCH_CONNECTOR_THROUGHPUT_TARGET_GPBS = (
    "S3_TORCH_CONNECTOR_THROUGHPUT_TARGET_GPBS"
)
ENV_S3_TORCH_CONNECTOR_PART_SIZE_MB = "S3_TORCH_CONNECTOR_PART_SIZE_MB"
DRIVE = "s3://"


def _get_default_bucket_region():
    # Return the region from the first matching environment variable, if any.
    for var in [
        ENV_S3_TORCH_CONNECTOR_REGION,
        "AWS_DEFAULT_REGION",
        "AWS_REGION",
        "REGION",
    ]:
        if var in os.environ:
            return os.environ[var]


def _get_default_throughput_target_gbps():
    if ENV_S3_TORCH_CONNECTOR_THROUGHPUT_TARGET_GPBS in os.environ:
        return float(os.environ[ENV_S3_TORCH_CONNECTOR_THROUGHPUT_TARGET_GPBS])


def _get_default_part_size():
    # Configured in MiB; converted to bytes for the client.
    if ENV_S3_TORCH_CONNECTOR_PART_SIZE_MB in os.environ:
        return int(os.environ[ENV_S3_TORCH_CONNECTOR_PART_SIZE_MB]) * 1024 * 1024


class S3Parser(ParserBase):
    # Adapts pathlib's path-parsing hooks to s3:// URIs: the "drive" is
    # s3://<bucket>, and the key is treated as a POSIX-style path.
    @classmethod
    def _unsupported_msg(cls, attribute):
        return f"{cls.__name__}.{attribute} is unsupported"

    @property
    def sep(self):
        return "/"

    def join(self, path, *paths):
        return posixpath.join(path, *paths)

    def split(self, path):
        scheme, bucket, prefix, _, _, _ = urlparse(path)
        parent, _, name = prefix.lstrip("/").rpartition("/")
        if not bucket:
            return bucket, name
        return (scheme + "://" + bucket + "/" + parent, name)

    def splitdrive(self, path):
        scheme, bucket, prefix, _, _, _ = urlparse(path)
        drive = f"{scheme}://{bucket}"
        return drive, prefix.lstrip("/")

    def splitext(self, path):
        return posixpath.splitext(path)

    def normcase(self, path):
        return posixpath.normcase(path)

    def isabs(self, path):
        # A path is absolute when it carries a scheme, e.g. "s3://bucket/key".
        s = os.fspath(path)
        scheme_tail = s.split("://", 1)
        return len(scheme_tail) == 2

class S3Path(PathBase):
    __slots__ = ("_region", "_s3_client_config", "_client", "_raw_path")
    parser = S3Parser()
    # stat() results are memoized briefly so repeated exists()/is_dir()
    # checks do not issue a HeadObject request each time.
    _stat_cache_ttl_seconds = 1
    _stat_cache_size = 1024
    _stat_cache = {}

    def __init__(
        self,
        *pathsegments,
        client: Optional[S3Client] = None,
        region=None,
        s3_client_config=None,
    ):
        super().__init__(*pathsegments)
        if not self.drive.startswith(DRIVE):
            raise ValueError("S3Path expects an S3 URI of the form s3://bucket/key")
        self._region = region or _get_default_bucket_region()
        self._s3_client_config = s3_client_config or S3ClientConfig(
            throughput_target_gbps=_get_default_throughput_target_gbps(),
            part_size=_get_default_part_size(),
        )
        self._client = client or S3Client(
            region=self._region,
            s3client_config=self._s3_client_config,
        )

    def __repr__(self):
        return f"{type(self).__name__}({str(self)!r})"

    def __hash__(self):
        return hash(str(self))

    def __eq__(self, other):
        if not isinstance(other, S3Path):
            return NotImplemented
        return str(self) == str(other)

    def with_segments(self, *pathsegments):
        # Build a derived path that shares this path's client and config.
        path = "/".join(pathsegments).lstrip("/")
        if not path.startswith(self.anchor):
            path = f"{self.anchor}{path}"
        return type(self)(
            path,
            client=self._client,
            region=self._region,
            s3_client_config=self._s3_client_config,
        )

    @property
    def bucket(self):
        if self.is_absolute() and self.drive.startswith(DRIVE):
            return self.drive[len(DRIVE):]
        return ""

    @property
    def key(self):
        if self.is_absolute() and len(self.parts) > 1:
            return self.parser.sep.join(self.parts[1:])
        return ""

    def open(self, mode="r", buffering=-1, encoding=None, errors=None, newline=None):
        if buffering != -1:
            raise ValueError("Only default buffering (-1) is supported.")
        if not self.is_absolute():
            raise ValueError("S3Path must be absolute.")
        # Reduce the mode to its action character, ignoring "b"/"t"/"U" flags.
        action = "".join(c for c in mode if c not in "btU")
        if action == "r":
            try:
                fileobj = self._client.get_object(self.bucket, self.key)
            except S3Exception:
                raise FileNotFoundError(errno.ENOENT, "Not found", str(self)) from None
        elif action == "w":
            fileobj = self._client.put_object(self.bucket, self.key)
        else:
            raise UnsupportedOperation()
        if "b" not in mode:
            fileobj = io.TextIOWrapper(fileobj, encoding, errors, newline)
        return fileobj

    def stat(self, *, follow_symlinks=True):
        cache_key = (self.bucket, self.key.rstrip("/"))
        cached_result = self._stat_cache.get(cache_key)
        if cached_result:
            result, timestamp = cached_result
            if time.time() - timestamp < self._stat_cache_ttl_seconds:
                return result
            del self._stat_cache[cache_key]  # Invalidate expired entry

        try:
            info = self._client.head_object(self.bucket, self.key.rstrip("/"))
            mode = stat.S_IFREG
        except S3Exception as e:
            # No object at this key; treat a non-empty prefix (objects or
            # sub-prefixes) as a directory.
            listobj = next(self._list_objects(max_keys=2))
            if len(listobj.object_info) > 0 or len(listobj.common_prefixes) > 0:
                info = SimpleNamespace(size=0, last_modified=None)
                mode = stat.S_IFDIR
            else:
                error_msg = f"No stats available for {self}; it may not exist."
                raise FileNotFoundError(error_msg) from e

        result = os.stat_result(
            (
                mode,  # mode
                None,  # ino
                DRIVE,  # dev
                None,  # nlink
                None,  # uid
                None,  # gid
                info.size,  # size
                None,  # atime
                info.last_modified or 0,  # mtime
                None,  # ctime
            )
        )
        if len(self._stat_cache) >= self._stat_cache_size:
            # Evict the oldest entry (dicts preserve insertion order).
            self._stat_cache.pop(next(iter(self._stat_cache)))
        self._stat_cache[cache_key] = (result, time.time())
        return result

    def iterdir(self):
        if not self.is_dir():
            raise NotADirectoryError(f"{self} is not an S3 folder")
        key = "" if not self.key else self.key.rstrip("/") + "/"
        for page in self._list_objects():
            # Yield directories (common prefixes) first.
            for prefix in page.common_prefixes:
                yield self.with_segments(prefix.rstrip("/"))
            for info in page.object_info:
                if info.key != key:  # Skip this directory's own marker object.
                    yield self.with_segments(info.key)

    def mkdir(self, mode=0o777, parents=False, exist_ok=False):
        # S3 has no real directories; create a zero-byte marker object whose
        # key ends in "/". mode and parents are accepted for signature
        # compatibility but have no effect here.
        if self.is_dir():
            if exist_ok:
                return
            raise FileExistsError(f"S3 folder {self} already exists.")
        writer = None
        try:
            writer = self._client.put_object(
                self.bucket, self.key if self.key.endswith("/") else self.key + "/"
            )
        finally:
            if writer is not None:
                writer.close()

    def unlink(self, missing_ok=False):
        if self.is_dir():
            if missing_ok:
                return
            raise IsADirectoryError(
                f"Path {self} is a directory; call rmdir instead of unlink."
            )
        self._client.delete_object(self.bucket, self.key)

    def rmdir(self):
        if not self.is_dir():
            raise NotADirectoryError(f"Path {self} is not an S3 folder")
        try:
            next(self.iterdir())
        except StopIteration:
            # The folder has no entries besides its own marker; safe to delete.
            self._client.delete_object(self.bucket, self.key)
        else:
            raise OSError(errno.ENOTEMPTY, "Directory not empty", str(self))

    def glob(self, pattern, *, case_sensitive=None, recurse_symlinks=True):
        if ".." in pattern:
            raise NotImplementedError(
                "Relative paths with '..' not supported in glob patterns"
            )
        if pattern.startswith(self.anchor) or pattern.startswith("/"):
            raise NotImplementedError("Non-relative patterns are unsupported")

        parts = list(PurePosixPath(pattern).parts)
        select = self._glob_selector(parts, case_sensitive, recurse_symlinks)
        return select(self)

    def with_name(self, name):
        """Return a new path with the file name changed."""
        split = self.parser.split
        if split(name)[0]:
            # Reject names that contain path separators or a bucket/scheme.
            raise ValueError(f"Invalid name {name!r}")
        return self.with_segments(str(self.parent), name)

    def _list_objects(self, max_keys: int = 1000):
        # Page through ListObjects results for this prefix, using "/" as the
        # delimiter so immediate children surface as common prefixes.
        key = "" if not self.key else self.key.rstrip("/") + "/"
        pages = iter(
            self._client.list_objects(
                self.bucket, key, delimiter="/", max_keys=max_keys
            )
        )
        for page in pages:
            yield page

    def __getstate__(self):
        # Exclude the native S3 client from the pickled state; it is
        # recreated on unpickling in __setstate__.
        state = {
            slot: getattr(self, slot, None)
            for cls in self.__class__.__mro__
            for slot in getattr(cls, "__slots__", [])
            if slot not in ["_client"]
        }
        return (None, state)

    def __setstate__(self, state):
        _, state_dict = state
        for slot, value in state_dict.items():
            if slot not in ["_client"]:
                setattr(self, slot, value)
        self._client = S3Client(
            region=self._region,
            s3client_config=self._s3_client_config,
        )
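
Taken together, open(), iterdir(), and stat() give S3 objects a pathlib-style surface. A hedged end-to-end sketch of the API defined above; the bucket, keys, and region are placeholders, and it assumes write access to the bucket:

from s3torchconnector import S3Path

root = S3Path("s3://my-bucket/checkpoints", region="us-east-1")

# "w" goes through put_object; text modes are wrapped in io.TextIOWrapper.
with (root / "epoch-0.txt").open("w") as f:
    f.write("step=0\n")

# "r" goes through get_object; a missing key raises FileNotFoundError.
with (root / "epoch-0.txt").open("r") as f:
    print(f.read())

# Listing walks ListObjects pages with delimiter="/", directories first.
for entry in root.iterdir():
    print(entry)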

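The __getstate__/__setstate__ pair exists so an S3Path can cross process boundaries, for example into PyTorch DataLoader workers, without attempting to pickle the native S3 client. A small sketch (placeholder names again):

import pickle

from s3torchconnector import S3Path

p = S3Path("s3://my-bucket/datasets/train", region="us-east-1")
restored = pickle.loads(pickle.dumps(p))  # _client is rebuilt in __setstate__
assert restored == p  # __eq__ compares the string forms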