Skip to content

Commit 6bfffd9

Browse files
authored
Fix: Handle no valid python files in the directory. (#83)
If you run torchfix script on a directory without Python files, it will terminate with an error. This modifies the script to just do nothing in such cases, without terminating with an error. ### Testing: Added test "test_no_python_files" ### Without fix the new test fails: assert 1 == 0 ...... raise Exception("Must have at least one job to process!")\nException: Must have at least one job to process!\n').returncode ......
1 parent 87289c1 commit 6bfffd9

File tree

2 files changed

+32
-12
lines changed

2 files changed

+32
-12
lines changed

tests/test_torchfix.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1+
import logging
2+
import subprocess
13
from pathlib import Path
4+
5+
import libcst.codemod as codemod
26
from torchfix.torchfix import (
3-
TorchChecker,
4-
TorchCodemod,
5-
TorchCodemodConfig,
67
DISABLED_BY_DEFAULT,
78
expand_error_codes,
8-
GET_ALL_VISITORS,
99
GET_ALL_ERROR_CODES,
10+
GET_ALL_VISITORS,
1011
process_error_code_str,
12+
TorchChecker,
13+
TorchCodemod,
14+
TorchCodemodConfig,
1115
)
12-
import logging
13-
import libcst.codemod as codemod
1416

1517
FIXTURES_PATH = Path(__file__).absolute().parent / "fixtures"
1618
LOGGER = logging.getLogger(__name__)
@@ -103,3 +105,18 @@ def test_errorcodes_distinct():
103105

104106
def test_parse_error_code_str(case, expected):
105107
assert process_error_code_str(case) == expected
108+
109+
110+
def test_no_python_files(tmp_path):
111+
# Create a temporary directory with no Python files
112+
non_python_file = tmp_path / "not_a_python_file.txt"
113+
non_python_file.write_text("This is not a Python file")
114+
115+
# Run torchfix on the temporary directory
116+
result = subprocess.run(
117+
["python3", "-m", "torchfix", str(tmp_path)],
118+
capture_output=True,
119+
text=True,
120+
)
121+
# Check that the script exits successfully
122+
assert result.returncode == 0

torchfix/__main__.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
import argparse
2-
import libcst.codemod as codemod
32

43
import contextlib
54
import ctypes
6-
import sys
75
import io
6+
import sys
7+
8+
import libcst.codemod as codemod
9+
10+
from .common import CYAN, ENDC
811

912
from .torchfix import (
10-
TorchCodemod,
11-
TorchCodemodConfig,
1213
__version__ as TorchFixVersion,
1314
DISABLED_BY_DEFAULT,
1415
GET_ALL_ERROR_CODES,
1516
process_error_code_str,
17+
TorchCodemod,
18+
TorchCodemodConfig,
1619
)
17-
from .common import CYAN, ENDC
1820

1921

2022
# Should get rid of this code eventually.
@@ -83,7 +85,6 @@ def _parse_args() -> argparse.Namespace:
8385

8486
def main() -> None:
8587
args = _parse_args()
86-
8788
files = codemod.gather_files(args.path)
8889

8990
# Filter out files that don't have "torch" string in them.
@@ -97,6 +98,8 @@ def main() -> None:
9798
torch_files.append(file)
9899
break
99100

101+
if not torch_files:
102+
return
100103
config = TorchCodemodConfig()
101104
config.select = list(process_error_code_str(args.select))
102105
command_instance = TorchCodemod(codemod.CodemodContext(), config)

0 commit comments

Comments
 (0)