Skip to content

Commit fc5db42

Browse files
Rong Rong (AI Infra)facebook-github-bot
Rong Rong (AI Infra)
authored andcommitted
[BE] replace unittest.main with run_tests (pytorch#50451)
Summary: fix pytorch#50448. This replaces all `test/*.py` files with run_tests(). This PR does not address test files in the subdirectories because they seems unrelated. Pull Request resolved: pytorch#50451 Reviewed By: janeyx99 Differential Revision: D25899924 Pulled By: walterddr fbshipit-source-id: f7c861f0096624b2791ad6ef6a16b1c4895cce71
1 parent a4383a6 commit fc5db42

8 files changed

+24
-24
lines changed

test/test_expecttest.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import torch.testing._internal.expecttest as expecttest
1+
from torch.testing._internal import expecttest
2+
from torch.testing._internal.common_utils import TestCase, run_tests
23

3-
import unittest
44
import string
55
import textwrap
66
import doctest
@@ -17,7 +17,7 @@ def text_lineno(draw):
1717
return (t, lineno)
1818

1919

20-
class TestExpectTest(expecttest.TestCase):
20+
class TestExpectTest(TestCase):
2121
@hypothesis.given(text_lineno())
2222
def test_nth_line_ref(self, t_lineno):
2323
t, lineno = t_lineno
@@ -103,4 +103,4 @@ def load_tests(loader, tests, ignore):
103103

104104

105105
if __name__ == '__main__':
106-
unittest.main()
106+
run_tests()

test/test_jit_disabled.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
import unittest
21
import sys
32
import os
43
import contextlib
54
import subprocess
6-
from torch.testing._internal.common_utils import TemporaryFileName
5+
from torch.testing._internal.common_utils import TestCase, run_tests, TemporaryFileName
76

87

98
@contextlib.contextmanager
@@ -16,7 +15,7 @@ def _jit_disabled():
1615
os.environ["PYTORCH_JIT"] = cur_env
1716

1817

19-
class TestJitDisabled(unittest.TestCase):
18+
class TestJitDisabled(TestCase):
2019
"""
2120
These tests are separate from the rest of the JIT tests because we need
2221
run a new subprocess and `import torch` with the correct environment
@@ -91,4 +90,4 @@ def forward(self, input):
9190
self.compare_enabled_disabled(_program_string)
9291

9392
if __name__ == '__main__':
94-
unittest.main()
93+
run_tests()

test/test_mobile_optimizer.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch.nn as nn
44
import torch.backends.xnnpack
55
import torch.utils.bundled_inputs
6+
from torch.testing._internal.common_utils import TestCase, run_tests
67
from torch.testing._internal.jit_utils import get_forward, get_forward_graph
78
from torch.utils.mobile_optimizer import *
89
from torch.nn import functional as F
@@ -12,7 +13,7 @@
1213

1314
FileCheck = torch._C.FileCheck
1415

15-
class TestOptimizer(unittest.TestCase):
16+
class TestOptimizer(TestCase):
1617

1718
@unittest.skipUnless(torch.backends.xnnpack.enabled,
1819
" XNNPACK must be enabled for these tests."
@@ -430,4 +431,4 @@ def _quant_script_and_optimize(model):
430431

431432

432433
if __name__ == '__main__':
433-
unittest.main()
434+
run_tests()

test/test_namedtuple_return_api.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import os
22
import re
33
import yaml
4-
import unittest
54
import textwrap
65
import torch
6+
7+
from torch.testing._internal.common_utils import TestCase, run_tests
78
from collections import namedtuple
89

910

@@ -17,7 +18,7 @@
1718
}
1819

1920

20-
class TestNamedTupleAPI(unittest.TestCase):
21+
class TestNamedTupleAPI(TestCase):
2122

2223
def test_native_functions_yaml(self):
2324
operators_found = set()
@@ -108,4 +109,4 @@ def check_namedtuple(tup, names):
108109

109110

110111
if __name__ == '__main__':
111-
unittest.main()
112+
run_tests()

test/test_overrides.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import torch
22
import numpy as np
3-
import unittest
43
import inspect
54
import functools
65
import pprint
76

8-
from torch.testing._internal.common_utils import TestCase
7+
from torch.testing._internal.common_utils import TestCase, run_tests
98
from torch.overrides import (
109
handle_torch_function,
1110
has_torch_function,
@@ -880,4 +879,4 @@ def f(a):
880879
self.assertEqual(f(A()), -1)
881880

882881
if __name__ == '__main__':
883-
unittest.main()
882+
run_tests()

test/test_package.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from unittest import main, skipIf
2-
from torch.testing._internal.common_utils import TestCase, IS_WINDOWS
1+
from unittest import skipIf
2+
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
33
from tempfile import NamedTemporaryFile
44
from torch.package import PackageExporter, PackageImporter
55
from pathlib import Path
@@ -392,4 +392,4 @@ def test_pickle_mocked(self):
392392

393393

394394
if __name__ == '__main__':
395-
main()
395+
run_tests()

test/test_show_pickle.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import torch
55
import torch.utils.show_pickle
66

7-
from torch.testing._internal.common_utils import IS_WINDOWS
7+
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
88

9-
class TestShowPickle(unittest.TestCase):
9+
class TestShowPickle(TestCase):
1010

1111
@unittest.skipIf(IS_WINDOWS, "Can't re-open temp file on Windows")
1212
def test_scripted_model(self):
@@ -31,4 +31,4 @@ def forward(self, x):
3131

3232

3333
if __name__ == '__main__':
34-
unittest.main()
34+
run_tests()

test/test_tensorexpr.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import nn
55
import unittest
66

7-
from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs
7+
from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests
88

99
from torch.testing._internal.te_utils import CudaCodeGenCreated, CudaCodeGenExecuted, \
1010
LLVMCodeGenExecuted, SimpleIREvalExecuted
@@ -1647,4 +1647,4 @@ def foo(a, b, c):
16471647
self.assertEqual(ref, exp)
16481648

16491649
if __name__ == '__main__':
1650-
unittest.main()
1650+
run_tests()

0 commit comments

Comments
 (0)