1
+ #include < gtest/gtest.h>
2
+
1
3
#include < torch/csrc/jit/ir/ir.h>
2
4
#include < torch/csrc/jit/ir/irparser.h>
3
5
#include < torch/csrc/jit/passes/constant_pooling.h>
4
6
#include < torch/csrc/jit/passes/constant_propagation.h>
5
7
#include < torch/csrc/jit/testing/file_check.h>
6
- #include " test/cpp/jit/test_base.h"
7
8
8
9
#include < sstream>
9
10
#include < string>
10
11
11
12
namespace torch {
12
13
namespace jit {
13
14
14
- void testConstantPooling () {
15
- {
16
- auto graph = std::make_shared<Graph>();
17
- parseIR (
18
- R"IR(
15
+ TEST (ConstantPoolingTest, Int) {
16
+ auto graph = std::make_shared<Graph>();
17
+ parseIR (
18
+ R"IR(
19
19
graph():
20
20
%8 : int = prim::Constant[value=1]()
21
21
%10 : int = prim::Constant[value=1]()
22
22
return (%8, %10)
23
23
)IR" ,
24
- &*graph);
25
- ConstantPooling (graph);
26
- testing::FileCheck ()
27
- .check_count (" prim::Constant" , 1 , /* exactly*/ true )
28
- ->run (*graph);
29
- }
30
- {
31
- auto graph = std::make_shared<Graph>();
32
- parseIR (
33
- R"IR(
24
+ &*graph);
25
+ ConstantPooling (graph);
26
+ testing::FileCheck ()
27
+ .check_count (" prim::Constant" , 1 , /* exactly*/ true )
28
+ ->run (*graph);
29
+ }
30
+
31
+ TEST (ConstantPoolingTest, PoolingAcrossBlocks) {
32
+ auto graph = std::make_shared<Graph>();
33
+ parseIR (
34
+ R"IR(
34
35
graph(%cond : Tensor):
35
36
%a : str = prim::Constant[value="bcd"]()
36
37
%3 : bool = aten::Bool(%cond)
@@ -44,17 +45,18 @@ graph(%cond : Tensor):
44
45
%7 : (str, str) = prim::TupleConstruct(%a, %b)
45
46
return (%7)
46
47
)IR" ,
47
- &*graph);
48
- ConstantPooling (graph);
49
- testing::FileCheck ()
50
- .check_count (" prim::Constant[value=\" abc\" ]" , 1 , /* exactly*/ true )
51
- ->check_count (" prim::Constant[value=\" bcd\" ]" , 1 , /* exactly*/ true )
52
- ->run (*graph);
53
- }
54
- {
55
- auto graph = std::make_shared<Graph>();
56
- parseIR (
57
- R"IR(
48
+ &*graph);
49
+ ConstantPooling (graph);
50
+ testing::FileCheck ()
51
+ .check_count (" prim::Constant[value=\" abc\" ]" , 1 , /* exactly*/ true )
52
+ ->check_count (" prim::Constant[value=\" bcd\" ]" , 1 , /* exactly*/ true )
53
+ ->run (*graph);
54
+ }
55
+
56
+ TEST (ConstantPoolingTest, PoolingDifferentDevices) {
57
+ auto graph = std::make_shared<Graph>();
58
+ parseIR (
59
+ R"IR(
58
60
graph():
59
61
%2 : int = prim::Constant[value=2]()
60
62
%1 : int = prim::Constant[value=1]()
@@ -70,22 +72,21 @@ graph():
70
72
prim::Print(%x, %y, %z)
71
73
return (%1)
72
74
)IR" ,
73
- &*graph);
74
- // three tensors created - two different devices among the three
75
- // don't have good support for parsing tensor constants
76
- ConstantPropagation (graph);
77
- ConstantPooling (graph);
78
- testing::FileCheck ()
79
- .check_count (
80
- " Float(2:1, requires_grad=0, device=cpu) = prim::Constant" ,
81
- 1 ,
82
- /* exactly*/ true )
83
- ->check_count (
84
- " Long(2:1, requires_grad=0, device=cpu) = prim::Constant" ,
85
- 1 ,
86
- /* exactly*/ true )
87
- ->run (*graph);
88
- }
75
+ &*graph);
76
+ // three tensors created - two different devices among the three
77
+ // don't have good support for parsing tensor constants
78
+ ConstantPropagation (graph);
79
+ ConstantPooling (graph);
80
+ testing::FileCheck ()
81
+ .check_count (
82
+ " Float(2:1, requires_grad=0, device=cpu) = prim::Constant" ,
83
+ 1 ,
84
+ /* exactly*/ true )
85
+ ->check_count (
86
+ " Long(2:1, requires_grad=0, device=cpu) = prim::Constant" ,
87
+ 1 ,
88
+ /* exactly*/ true )
89
+ ->run (*graph);
89
90
}
90
91
} // namespace jit
91
92
} // namespace torch
0 commit comments