Skip to content

Commit 07df305

Browse files
authored
Correctly pass filename to file creation call (#343)
Resolves #338 Follow up to #226 Filenames from open file handles were not being passed correctly in `create` / `async_create`. This PR fixes that and adds some test coverage. --------- Signed-off-by: Mattt Zmuda <[email protected]>
1 parent d4df7f6 commit 07df305

File tree

3 files changed

+106
-24
lines changed

3 files changed

+106
-24
lines changed

.vscode/settings.json

+3-22
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,12 @@
11
{
22
"editor.formatOnSave": true,
3-
"editor.formatOnType": true,
4-
"editor.formatOnPaste": true,
5-
"editor.renderControlCharacters": true,
6-
"editor.suggest.localityBonus": true,
73
"files.insertFinalNewline": true,
84
"files.trimFinalNewlines": true,
95
"[python]": {
106
"editor.defaultFormatter": "charliermarsh.ruff",
11-
"editor.formatOnSave": true,
127
"editor.codeActionsOnSave": {
13-
"source.fixAll": "explicit",
14-
"source.organizeImports": "explicit"
8+
"source.fixAll.ruff": "explicit",
9+
"source.organizeImports.ruff": "explicit"
1510
}
16-
},
17-
"python.languageServer": "Pylance",
18-
"python.analysis.typeCheckingMode": "basic",
19-
"python.testing.pytestArgs": [
20-
"-vvv",
21-
"python"
22-
],
23-
"python.testing.unittestEnabled": false,
24-
"python.testing.pytestEnabled": true,
25-
"ruff.lint.args": [
26-
"--config=pyproject.toml"
27-
],
28-
"ruff.format.args": [
29-
"--config=pyproject.toml"
30-
],
11+
}
3112
}

replicate/file.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ def create(
7171
"""
7272

7373
if isinstance(file, (str, pathlib.Path)):
74+
file_path = pathlib.Path(file)
75+
params["filename"] = file_path.name
7476
with open(file, "rb") as f:
7577
return self.create(f, **params)
7678
elif not isinstance(file, (io.IOBase, BinaryIO)):
@@ -92,8 +94,10 @@ async def async_create(
9294
"""Upload a file asynchronously that can be passed as an input when running a model."""
9395

9496
if isinstance(file, (str, pathlib.Path)):
95-
with open(file, "rb") as f:
96-
return self.create(f, **params)
97+
file_path = pathlib.Path(file)
98+
params["filename"] = file_path.name
99+
with open(file_path, "rb") as f:
100+
return await self.async_create(f, **params)
97101
elif not isinstance(file, (io.IOBase, BinaryIO)):
98102
raise ValueError(
99103
"Unsupported file type. Must be a file path or file-like object."

tests/test_file.py

+97
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,108 @@
33

44
import httpx
55
import pytest
6+
import respx
67

78
import replicate
9+
from replicate.client import Client
810

911
from .conftest import skip_if_no_token
1012

13+
router = respx.Router(base_url="https://api.replicate.com/v1")
14+
15+
router.route(
16+
method="POST",
17+
path="/files",
18+
name="files.create",
19+
).mock(
20+
return_value=httpx.Response(
21+
201,
22+
json={
23+
"id": "0ZjcyLWFhZjkNGZiNmY2YzQtMThhZi0tODg4NTY0NWNlMDEy",
24+
"name": "hello.txt",
25+
"size": 14,
26+
"content_type": "text/plain",
27+
"etag": "746308829575e17c3331bbcb00c0898b",
28+
"checksums": {
29+
"md5": "746308829575e17c3331bbcb00c0898b",
30+
"sha256": "d9014c4624844aa5bac314773d6b689ad467fa4e1d1a50a1b8a99d5a95f72ff5",
31+
},
32+
"metadata": {
33+
"foo": "bar",
34+
},
35+
"urls": {
36+
"get": "https://api.replicate.com/v1/files/0ZjcyLWFhZjkNGZiNmY2YzQtMThhZi0tODg4NTY0NWNlMDEy",
37+
},
38+
"created_at": "2024-08-22T12:26:51.079Z",
39+
"expires_at": "2024-08-22T13:26:51.079Z",
40+
},
41+
)
42+
)
43+
44+
45+
@pytest.mark.asyncio
46+
@pytest.mark.parametrize("async_flag", [True, False])
47+
@pytest.mark.parametrize("use_path", [True, False])
48+
async def test_file_create(async_flag, use_path):
49+
client = Client(
50+
api_token="test-token", transport=httpx.MockTransport(router.handler)
51+
)
52+
53+
temp_dir = tempfile.mkdtemp()
54+
temp_file_path = os.path.join(temp_dir, "hello.txt")
55+
56+
try:
57+
with open(temp_file_path, "w", encoding="utf-8") as temp_file:
58+
temp_file.write("Hello, world!")
59+
60+
metadata = {"foo": "bar"}
61+
62+
if use_path:
63+
file_arg = temp_file_path
64+
if async_flag:
65+
created_file = await client.files.async_create(
66+
file_arg, metadata=metadata
67+
)
68+
else:
69+
created_file = client.files.create(file_arg, metadata=metadata)
70+
else:
71+
with open(temp_file_path, "rb") as file_arg:
72+
if async_flag:
73+
created_file = await client.files.async_create(
74+
file_arg, metadata=metadata
75+
)
76+
else:
77+
created_file = client.files.create(file_arg, metadata=metadata)
78+
79+
assert router["files.create"].called
80+
request = router["files.create"].calls[0].request
81+
82+
# Check that the request is multipart/form-data
83+
assert request.headers["Content-Type"].startswith("multipart/form-data")
84+
85+
# Check that the filename is included and matches the fixed file name
86+
assert b'filename="hello.txt"' in request.content
87+
assert b"Hello, world!" in request.content
88+
89+
# Check the response
90+
assert created_file.id == "0ZjcyLWFhZjkNGZiNmY2YzQtMThhZi0tODg4NTY0NWNlMDEy"
91+
assert created_file.name == "hello.txt"
92+
assert created_file.size == 14
93+
assert created_file.content_type == "text/plain"
94+
assert created_file.etag == "746308829575e17c3331bbcb00c0898b"
95+
assert created_file.checksums == {
96+
"md5": "746308829575e17c3331bbcb00c0898b",
97+
"sha256": "d9014c4624844aa5bac314773d6b689ad467fa4e1d1a50a1b8a99d5a95f72ff5",
98+
}
99+
assert created_file.metadata == metadata
100+
assert created_file.urls == {
101+
"get": "https://api.replicate.com/v1/files/0ZjcyLWFhZjkNGZiNmY2YzQtMThhZi0tODg4NTY0NWNlMDEy",
102+
}
103+
104+
finally:
105+
os.unlink(temp_file_path)
106+
os.rmdir(temp_dir)
107+
11108

12109
@skip_if_no_token
13110
@pytest.mark.skipif(os.environ.get("CI") is not None, reason="Do not run on CI")

0 commit comments

Comments
 (0)