Skip to content

Commit 6d21d5f

Browse files
suofacebook-github-bot
authored andcommitted
gtest-ify JIT tests, through the letter c (pytorch#45249)
Summary: Pull Request resolved: pytorch#45249 Reland of pytorch#45055 and pytorch#45020 See pytorch#45018 for context. Test Plan: Imported from OSS Reviewed By: jamesr66a Differential Revision: D23892645 Pulled By: suo fbshipit-source-id: e7fe58d5e1a5a0c44f4e2aec9694145afabde0fd
1 parent 29dc3c5 commit 6d21d5f

16 files changed

+641
-634
lines changed

test/cpp/jit/CMakeLists.txt

+6-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ set(JIT_TEST_ROOT ${TORCH_ROOT}/test/cpp/jit)
22

33
# Build separate libraries the define custom classes/operators used from our Python tests.
44
# These are intended to be used with torch.ops.load_library() in our Python test suite.
5-
add_library(torchbind_test SHARED ${JIT_TEST_ROOT}/test_custom_class.cpp)
5+
add_library(torchbind_test SHARED
6+
${JIT_TEST_ROOT}/test_custom_class_registrations.h
7+
${JIT_TEST_ROOT}/test_custom_class_registrations.cpp
8+
)
69
target_link_libraries(torchbind_test torch)
710

811
add_library(jitbackend_test SHARED ${JIT_TEST_ROOT}/test_backend.cpp)
@@ -30,6 +33,8 @@ set(JIT_TEST_SRCS
3033
${JIT_TEST_ROOT}/test_cleanup_passes.cpp
3134
${JIT_TEST_ROOT}/test_create_autodiff_subgraphs.cpp
3235
${JIT_TEST_ROOT}/test_custom_class.cpp
36+
${JIT_TEST_ROOT}/test_custom_class_registrations.h
37+
${JIT_TEST_ROOT}/test_custom_class_registrations.cpp
3338
${JIT_TEST_ROOT}/test_custom_operators.cpp
3439
${JIT_TEST_ROOT}/test_dce.cpp
3540
${JIT_TEST_ROOT}/test_fuser.cpp

test/cpp/jit/test_autodiff.cpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
#include "test/cpp/jit/test_base.h"
1+
#include <gtest/gtest.h>
2+
23
#include "test/cpp/jit/test_utils.h"
34
#include "torch/csrc/jit/frontend/tracer.h"
45
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
@@ -83,7 +84,7 @@ variable_list grad(
8384
fmap(inputs, get_edge));
8485
}
8586

86-
void testADFormulas() {
87+
TEST(AutodiffTest, ADFormulas) {
8788
const auto cast = [](const Variable& v) {
8889
return static_cast<at::Tensor>(v);
8990
};
@@ -174,7 +175,7 @@ void testADFormulas() {
174175
}
175176
}
176177

177-
void testDifferentiate() {
178+
TEST(AutodiffTest, Differentiate) {
178179
// Note: can't use IRParser for this test due to issue #23989
179180
auto graph = std::make_shared<Graph>();
180181
std::vector<int64_t> sizes{2, 3, 4};
@@ -229,7 +230,7 @@ void testDifferentiate() {
229230
->run(*grad_spec.df);
230231
}
231232

232-
void testDifferentiateWithRequiresGrad() {
233+
TEST(AutodiffTest, DifferentiateWithRequiresGrad) {
233234
const auto graph_string = R"IR(
234235
graph(%0 : Tensor,
235236
%1 : Tensor):

test/cpp/jit/test_class_import.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
#include <test/cpp/jit/test_base.h>
2-
#include <test/cpp/jit/test_utils.h>
1+
#include <gtest/gtest.h>
32

43
#include <ATen/core/qualified_name.h>
4+
#include <test/cpp/jit/test_utils.h>
55
#include <torch/csrc/jit/frontend/resolver.h>
66
#include <torch/csrc/jit/serialization/import_source.h>
77
#include <torch/torch.h>
@@ -45,7 +45,7 @@ static void import_libs(
4545
si.loadType(QualifiedName(class_name));
4646
}
4747

48-
void testClassImport() {
48+
TEST(ClassImportTest, Basic) {
4949
auto cu1 = std::make_shared<CompilationUnit>();
5050
auto cu2 = std::make_shared<CompilationUnit>();
5151
std::vector<at::IValue> constantTable;
@@ -80,7 +80,7 @@ void testClassImport() {
8080
ASSERT_FALSE(c);
8181
}
8282

83-
void testScriptObject() {
83+
TEST(ClassImportTest, ScriptObject) {
8484
Module m1("m1");
8585
Module m2("m2");
8686
std::vector<at::IValue> constantTable;
@@ -114,7 +114,7 @@ def __init__(self, x):
114114
return x
115115
)JIT";
116116

117-
void testClassDerive() {
117+
TEST(ClassImportTest, ClassDerive) {
118118
auto cu = std::make_shared<CompilationUnit>();
119119
auto cls = ClassType::create("foo.bar", cu);
120120
const auto self = SimpleSelf(cls);
@@ -142,7 +142,7 @@ class FooBar1234(Module):
142142
return (self.f).top()
143143
)JIT";
144144

145-
void testSaveLoadTorchbind() {
145+
TEST(ClassImportTest, CustomClass) {
146146
auto cu1 = std::make_shared<CompilationUnit>();
147147
std::vector<at::IValue> constantTable;
148148
// Import different versions of FooTest into two namespaces.

test/cpp/jit/test_class_parser.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#include <gtest/gtest.h>
2+
13
#include <test/cpp/jit/test_base.h>
24
#include <torch/csrc/jit/frontend/parser.h>
35
#include <torch/csrc/jit/frontend/resolver.h>
@@ -15,7 +17,7 @@ const auto testSource = R"JIT(
1517
an_attribute : Tensor
1618
)JIT";
1719

18-
void testClassParser() {
20+
TEST(ClassParserTest, Basic) {
1921
Parser p(std::make_shared<Source>(testSource));
2022
std::vector<Def> definitions;
2123
std::vector<Resolver> resolvers;

test/cpp/jit/test_cleanup_passes.cpp

+18-19
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
1+
#include <gtest/gtest.h>
2+
13
#include <torch/csrc/jit/frontend/ir_emitter.h>
24
#include <torch/csrc/jit/ir/ir.h>
35
#include <torch/csrc/jit/ir/irparser.h>
46
#include <torch/csrc/jit/testing/file_check.h>
5-
#include "test/cpp/jit/test_base.h"
67

78
namespace torch {
89
namespace jit {
910

10-
void testCleanUpPasses() {
11+
TEST(CleanupPassTest, Basic) {
1112
// Tests stability of clean up passes when dealing with constant pooling
1213
// and constant propagation.
13-
{
14-
auto graph = std::make_shared<Graph>();
15-
parseIR(
16-
R"IR(
14+
auto graph = std::make_shared<Graph>();
15+
parseIR(
16+
R"IR(
1717
graph(%cond.1 : Tensor,
1818
%suffix.1 : str):
1919
%3 : bool = aten::Bool(%cond.1) # o.py:6:7
@@ -31,20 +31,19 @@ graph(%cond.1 : Tensor,
3131
-> (%12)
3232
return (%25)
3333
)IR",
34-
&*graph);
35-
runCleanupPasses(graph);
36-
testing::FileCheck()
37-
.check_count(
38-
"prim::Constant[value=\"same string with a twist\"]",
39-
1,
40-
/*exactly=*/true)
41-
->run(*graph);
34+
&*graph);
35+
runCleanupPasses(graph);
36+
testing::FileCheck()
37+
.check_count(
38+
"prim::Constant[value=\"same string with a twist\"]",
39+
1,
40+
/*exactly=*/true)
41+
->run(*graph);
4242

43-
auto graph_after_pass_once = graph->toString();
44-
runCleanupPasses(graph);
45-
auto graph_after_pass_twice = graph->toString();
46-
ASSERT_EQ(graph_after_pass_once, graph_after_pass_twice);
47-
}
43+
auto graph_after_pass_once = graph->toString();
44+
runCleanupPasses(graph);
45+
auto graph_after_pass_twice = graph->toString();
46+
ASSERT_EQ(graph_after_pass_once, graph_after_pass_twice);
4847
}
4948
} // namespace jit
5049
} // namespace torch

test/cpp/jit/test_code_template.cpp

+24-26
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
#include "test/cpp/jit/test_base.h"
2-
#include "test/cpp/jit/test_utils.h"
1+
#include <gtest/gtest.h>
32

3+
#include "test/cpp/jit/test_utils.h"
44
#include "torch/csrc/jit/frontend/code_template.h"
55

66
namespace torch {
@@ -33,31 +33,29 @@ static const auto ct_expect = R"(
3333
int notest(int a)
3434
)";
3535

36-
void testCodeTemplate() {
37-
{
38-
TemplateEnv e;
39-
e.s("hi", "foo");
40-
e.v("what", {"is", "this"});
41-
TemplateEnv c(e);
42-
c.s("hi", "foo2");
43-
ASSERT_EQ(e.s("hi"), "foo");
44-
ASSERT_EQ(c.s("hi"), "foo2");
45-
ASSERT_EQ(e.v("what")[0], "is");
46-
}
36+
TEST(TestCodeTemplate, Copying) {
37+
TemplateEnv e;
38+
e.s("hi", "foo");
39+
e.v("what", {"is", "this"});
40+
TemplateEnv c(e);
41+
c.s("hi", "foo2");
42+
ASSERT_EQ(e.s("hi"), "foo");
43+
ASSERT_EQ(c.s("hi"), "foo2");
44+
ASSERT_EQ(e.v("what")[0], "is");
45+
}
4746

48-
{
49-
TemplateEnv e;
50-
e.v("args", {"hi", "8"});
51-
e.v("bar", {"what\non many\nlines...", "7"});
52-
e.s("a", "3");
53-
e.s("b", "4");
54-
e.v("stuff", {"things...", "others"});
55-
e.v("empty", {});
56-
auto s = ct.format(e);
57-
// std::cout << "'" << s << "'\n";
58-
// std::cout << "'" << ct_expect << "'\n";
59-
ASSERT_EQ(s, ct_expect);
60-
}
47+
TEST(TestCodeTemplate, Formatting) {
48+
TemplateEnv e;
49+
e.v("args", {"hi", "8"});
50+
e.v("bar", {"what\non many\nlines...", "7"});
51+
e.s("a", "3");
52+
e.s("b", "4");
53+
e.v("stuff", {"things...", "others"});
54+
e.v("empty", {});
55+
auto s = ct.format(e);
56+
// std::cout << "'" << s << "'\n";
57+
// std::cout << "'" << ct_expect << "'\n";
58+
ASSERT_EQ(s, ct_expect);
6159
}
6260

6361
} // namespace jit
+44-43
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,37 @@
1+
#include <gtest/gtest.h>
2+
13
#include <torch/csrc/jit/ir/ir.h>
24
#include <torch/csrc/jit/ir/irparser.h>
35
#include <torch/csrc/jit/passes/constant_pooling.h>
46
#include <torch/csrc/jit/passes/constant_propagation.h>
57
#include <torch/csrc/jit/testing/file_check.h>
6-
#include "test/cpp/jit/test_base.h"
78

89
#include <sstream>
910
#include <string>
1011

1112
namespace torch {
1213
namespace jit {
1314

14-
void testConstantPooling() {
15-
{
16-
auto graph = std::make_shared<Graph>();
17-
parseIR(
18-
R"IR(
15+
TEST(ConstantPoolingTest, Int) {
16+
auto graph = std::make_shared<Graph>();
17+
parseIR(
18+
R"IR(
1919
graph():
2020
%8 : int = prim::Constant[value=1]()
2121
%10 : int = prim::Constant[value=1]()
2222
return (%8, %10)
2323
)IR",
24-
&*graph);
25-
ConstantPooling(graph);
26-
testing::FileCheck()
27-
.check_count("prim::Constant", 1, /*exactly*/ true)
28-
->run(*graph);
29-
}
30-
{
31-
auto graph = std::make_shared<Graph>();
32-
parseIR(
33-
R"IR(
24+
&*graph);
25+
ConstantPooling(graph);
26+
testing::FileCheck()
27+
.check_count("prim::Constant", 1, /*exactly*/ true)
28+
->run(*graph);
29+
}
30+
31+
TEST(ConstantPoolingTest, PoolingAcrossBlocks) {
32+
auto graph = std::make_shared<Graph>();
33+
parseIR(
34+
R"IR(
3435
graph(%cond : Tensor):
3536
%a : str = prim::Constant[value="bcd"]()
3637
%3 : bool = aten::Bool(%cond)
@@ -44,17 +45,18 @@ graph(%cond : Tensor):
4445
%7 : (str, str) = prim::TupleConstruct(%a, %b)
4546
return (%7)
4647
)IR",
47-
&*graph);
48-
ConstantPooling(graph);
49-
testing::FileCheck()
50-
.check_count("prim::Constant[value=\"abc\"]", 1, /*exactly*/ true)
51-
->check_count("prim::Constant[value=\"bcd\"]", 1, /*exactly*/ true)
52-
->run(*graph);
53-
}
54-
{
55-
auto graph = std::make_shared<Graph>();
56-
parseIR(
57-
R"IR(
48+
&*graph);
49+
ConstantPooling(graph);
50+
testing::FileCheck()
51+
.check_count("prim::Constant[value=\"abc\"]", 1, /*exactly*/ true)
52+
->check_count("prim::Constant[value=\"bcd\"]", 1, /*exactly*/ true)
53+
->run(*graph);
54+
}
55+
56+
TEST(ConstantPoolingTest, PoolingDifferentDevices) {
57+
auto graph = std::make_shared<Graph>();
58+
parseIR(
59+
R"IR(
5860
graph():
5961
%2 : int = prim::Constant[value=2]()
6062
%1 : int = prim::Constant[value=1]()
@@ -70,22 +72,21 @@ graph():
7072
prim::Print(%x, %y, %z)
7173
return (%1)
7274
)IR",
73-
&*graph);
74-
// three tensors created - two different devices among the three
75-
// don't have good support for parsing tensor constants
76-
ConstantPropagation(graph);
77-
ConstantPooling(graph);
78-
testing::FileCheck()
79-
.check_count(
80-
"Float(2:1, requires_grad=0, device=cpu) = prim::Constant",
81-
1,
82-
/*exactly*/ true)
83-
->check_count(
84-
"Long(2:1, requires_grad=0, device=cpu) = prim::Constant",
85-
1,
86-
/*exactly*/ true)
87-
->run(*graph);
88-
}
75+
&*graph);
76+
// three tensors created - two different devices among the three
77+
// don't have good support for parsing tensor constants
78+
ConstantPropagation(graph);
79+
ConstantPooling(graph);
80+
testing::FileCheck()
81+
.check_count(
82+
"Float(2:1, requires_grad=0, device=cpu) = prim::Constant",
83+
1,
84+
/*exactly*/ true)
85+
->check_count(
86+
"Long(2:1, requires_grad=0, device=cpu) = prim::Constant",
87+
1,
88+
/*exactly*/ true)
89+
->run(*graph);
8990
}
9091
} // namespace jit
9192
} // namespace torch

test/cpp/jit/test_create_autodiff_subgraphs.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
#include "test/cpp/jit/test_base.h"
1+
#include <gtest/gtest.h>
2+
23
#include "test/cpp/jit/test_utils.h"
34

45
#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
56

67
namespace torch {
78
namespace jit {
89

9-
void testCreateAutodiffSubgraphs() {
10+
TEST(CreateAutodiffSubgraphsTest, Basic) {
1011
auto graph = build_lstm();
1112
CreateAutodiffSubgraphs(graph, /*threshold=*/2);
1213
// all of the ops are within the DifferentiableGraph

0 commit comments

Comments
 (0)