@@ -15,115 +15,135 @@ const kExampleInputDescriptor = {
15
15
16
16
const tests = [
17
17
{
18
- name :
19
- '[where] Throw if the condition data type is not uint8.' ,
18
+ name : '[where] Throw if the condition data type is not uint8.' ,
20
19
condition : { dataType : 'float32' , dimensions : [ 2 , 4 ] } ,
21
- input : { dataType : 'float32' , dimensions : [ 2 , 4 ] } ,
22
- other : { dataType : 'float32' , dimensions : [ 2 , 4 ] } ,
20
+ trueValue : { dataType : 'float32' , dimensions : [ 2 , 4 ] } ,
21
+ falseValue : { dataType : 'float32' , dimensions : [ 2 , 4 ] } ,
23
22
} ,
24
23
{
25
24
name :
26
- '[where] Throw if the data types of input and other do not match' ,
25
+ '[where] Throw if the data types of trueValue and falseValue do not match' ,
27
26
condition : { dataType : 'uint8' , dimensions : [ 2 , 4 ] } ,
28
- input : { dataType : 'float16' , dimensions : [ 2 , 4 ] } ,
29
- other : { dataType : 'float32' , dimensions : [ 2 , 4 ] } ,
27
+ trueValue : { dataType : 'float16' , dimensions : [ 2 , 4 ] } ,
28
+ falseValue : { dataType : 'float32' , dimensions : [ 2 , 4 ] } ,
30
29
} ,
31
30
{
32
31
name :
33
- '[where] Throw if the shapes of input and other are not broadcastable' ,
32
+ '[where] Throw if the shapes of trueValue and falseValue are not broadcastable' ,
34
33
condition : { dataType : 'uint8' , dimensions : [ 2 , 4 ] } ,
35
- input : { dataType : 'float32' , dimensions : [ 2 , 4 ] } ,
36
- other : { dataType : 'float32' , dimensions : [ 2 , 3 ] } ,
34
+ trueValue : { dataType : 'float32' , dimensions : [ 2 , 4 ] } ,
35
+ falseValue : { dataType : 'float32' , dimensions : [ 2 , 3 ] } ,
37
36
} ,
38
37
{
39
- name :
40
- '[where] Throw if the condition shape is not broadcastable' ,
38
+ name : '[where] Throw if the condition shape is not broadcastable' ,
41
39
condition : { dataType : 'uint8' , dimensions : [ 2 , 4 ] } ,
42
- input : { dataType : 'float32' , dimensions : [ 2 , 3 ] } ,
43
- other : { dataType : 'float32' , dimensions : [ 2 , 1 ] } ,
40
+ trueValue : { dataType : 'float32' , dimensions : [ 2 , 3 ] } ,
41
+ falseValue : { dataType : 'float32' , dimensions : [ 2 , 1 ] } ,
44
42
} ,
45
43
{
46
44
name :
47
- '[where] Test building where with 2-D condition, 2-D input and 2-D other using broadcast' ,
45
+ '[where] Test building where with 2-D condition, 2-D trueValue and 2-D falseValue using broadcast' ,
48
46
condition : { dataType : 'uint8' , dimensions : [ 2 , 1 ] } ,
49
- input : { dataType : 'float32' , dimensions : [ 2 , 4 ] } ,
50
- other : { dataType : 'float32' , dimensions : [ 2 , 4 ] } ,
47
+ trueValue : { dataType : 'float32' , dimensions : [ 2 , 4 ] } ,
48
+ falseValue : { dataType : 'float32' , dimensions : [ 2 , 4 ] } ,
51
49
output : { dataType : 'float32' , dimensions : [ 2 , 4 ] } ,
52
50
} ,
53
51
{
54
52
name :
55
- '[where] Test building where with 2-D condition, 2-D input and 3-D other using broadcast' ,
53
+ '[where] Test building where with 2-D condition, 2-D trueValue and 3-D falseValue using broadcast' ,
56
54
condition : { dataType : 'uint8' , dimensions : [ 1 , 4 ] } ,
57
- input : { dataType : 'float32 ' , dimensions : [ 3 , 4 ] } ,
58
- other : { dataType : 'float32 ' , dimensions : [ 2 , 3 , 4 ] } ,
59
- output : { dataType : 'float32 ' , dimensions : [ 2 , 3 , 4 ] } ,
55
+ trueValue : { dataType : 'float16 ' , dimensions : [ 3 , 4 ] } ,
56
+ falseValue : { dataType : 'float16 ' , dimensions : [ 2 , 3 , 4 ] } ,
57
+ output : { dataType : 'float16 ' , dimensions : [ 2 , 3 , 4 ] } ,
60
58
} ,
61
59
{
62
60
name :
63
- '[where] Test building where with 3-D condition, 3-D input and 2-D other using broadcast' ,
61
+ '[where] Test building where with 3-D condition, 3-D trueValue and 2-D falseValue using broadcast' ,
64
62
condition : { dataType : 'uint8' , dimensions : [ 2 , 1 , 4 ] } ,
65
- input : { dataType : 'float32 ' , dimensions : [ 2 , 3 , 4 ] } ,
66
- other : { dataType : 'float32 ' , dimensions : [ 1 , 4 ] } ,
67
- output : { dataType : 'float32 ' , dimensions : [ 2 , 3 , 4 ] } ,
63
+ trueValue : { dataType : 'int32 ' , dimensions : [ 2 , 3 , 4 ] } ,
64
+ falseValue : { dataType : 'int32 ' , dimensions : [ 1 , 4 ] } ,
65
+ output : { dataType : 'int32 ' , dimensions : [ 2 , 3 , 4 ] } ,
68
66
} ,
69
67
{
70
68
name :
71
- '[where] Test building where with 4-D condition, 3-D input and 2-D other using broadcast' ,
69
+ '[where] Test building where with 4-D condition, 3-D trueValue and 2-D falseValue using broadcast' ,
72
70
condition : { dataType : 'uint8' , dimensions : [ 2 , 3 , 4 , 5 ] } ,
73
- input : { dataType : 'float32 ' , dimensions : [ 3 , 4 , 5 ] } ,
74
- other : { dataType : 'float32 ' , dimensions : [ 4 , 5 ] } ,
75
- output : { dataType : 'float32 ' , dimensions : [ 2 , 3 , 4 , 5 ] } ,
71
+ trueValue : { dataType : 'uint32 ' , dimensions : [ 3 , 4 , 5 ] } ,
72
+ falseValue : { dataType : 'uint32 ' , dimensions : [ 4 , 5 ] } ,
73
+ output : { dataType : 'uint32 ' , dimensions : [ 2 , 3 , 4 , 5 ] } ,
76
74
}
77
75
] ;
78
76
79
77
tests . forEach (
80
78
test => promise_test ( async t => {
79
+ for ( let operand of [ test . condition , test . trueValue , test . falseValue ] ) {
80
+ if ( ! context . opSupportLimits ( ) . input . dataTypes . includes (
81
+ operand . dataType ) ) {
82
+ assert_throws_js ( TypeError , ( ) => builder . input ( 'input' , {
83
+ dataType : operand . dataType ,
84
+ dimensions : operand . dimensions
85
+ } ) ) ;
86
+ return ;
87
+ }
88
+ }
89
+
81
90
const condition = builder . input ( 'condition' , {
82
91
dataType : test . condition . dataType ,
83
92
dimensions : test . condition . dimensions
84
93
} ) ;
85
- const input = builder . input (
86
- 'input' ,
87
- { dataType : test . input . dataType , dimensions : test . input . dimensions } ) ;
88
- const other = builder . input (
89
- 'other' ,
90
- { dataType : test . other . dataType , dimensions : test . other . dimensions } ) ;
91
- if ( test . output ) {
92
- const output = builder . where ( condition , input , other ) ;
94
+ const trueValue = builder . input ( 'trueValue' , {
95
+ dataType : test . trueValue . dataType ,
96
+ dimensions : test . trueValue . dimensions
97
+ } ) ;
98
+ const falseValue = builder . input ( 'falseValue' , {
99
+ dataType : test . falseValue . dataType ,
100
+ dimensions : test . falseValue . dimensions
101
+ } ) ;
102
+ if ( test . output &&
103
+ context . opSupportLimits ( ) . where . condition . dataTypes . includes (
104
+ test . condition . dataType ) &&
105
+ context . opSupportLimits ( ) . where . trueValue . dataTypes . includes (
106
+ test . trueValue . dataType ) &&
107
+ context . opSupportLimits ( ) . where . falseValue . dataTypes . includes (
108
+ test . falseValue . dataType ) ) {
109
+ const output = builder . where ( condition , trueValue , falseValue ) ;
93
110
assert_equals ( output . dataType ( ) , test . output . dataType ) ;
94
111
assert_array_equals ( output . shape ( ) , test . output . dimensions ) ;
95
112
} else {
96
113
assert_throws_js (
97
- TypeError , ( ) => builder . where ( condition , input , other ) ) ;
114
+ TypeError , ( ) => builder . where ( condition , trueValue , falseValue ) ) ;
98
115
}
99
116
} , test . name ) ) ;
100
117
101
118
multi_builder_test ( async ( t , builder , otherBuilder ) => {
102
119
const conditionFromOtherBuilder =
103
120
otherBuilder . input ( 'condition' , kExampleConditionDescriptor ) ;
104
121
105
- const input = builder . input ( 'input ' , kExampleInputDescriptor ) ;
106
- const other = builder . input ( 'other ' , kExampleInputDescriptor ) ;
122
+ const trueValue = builder . input ( 'trueValue ' , kExampleInputDescriptor ) ;
123
+ const falseValue = builder . input ( 'falseValue ' , kExampleInputDescriptor ) ;
107
124
assert_throws_js (
108
- TypeError , ( ) => builder . where ( conditionFromOtherBuilder , input , other ) ) ;
125
+ TypeError ,
126
+ ( ) => builder . where ( conditionFromOtherBuilder , trueValue , falseValue ) ) ;
109
127
} , '[where] throw if condition is from another builder' ) ;
110
128
111
129
multi_builder_test ( async ( t , builder , otherBuilder ) => {
112
- const inputFromOtherBuilder =
113
- otherBuilder . input ( 'input ' , kExampleInputDescriptor ) ;
130
+ const trueValueFromOtherBuilder =
131
+ otherBuilder . input ( 'trueValue ' , kExampleInputDescriptor ) ;
114
132
115
133
const condition = builder . input ( 'condition' , kExampleConditionDescriptor ) ;
116
- const other = builder . input ( 'other ' , kExampleInputDescriptor ) ;
134
+ const falseValue = builder . input ( 'falseValue ' , kExampleInputDescriptor ) ;
117
135
assert_throws_js (
118
- TypeError , ( ) => builder . where ( condition , inputFromOtherBuilder , other ) ) ;
119
- } , '[where] throw if input is from another builder' ) ;
136
+ TypeError ,
137
+ ( ) => builder . where ( condition , trueValueFromOtherBuilder , falseValue ) ) ;
138
+ } , '[where] throw if trueValue is from another builder' ) ;
120
139
121
140
multi_builder_test ( async ( t , builder , otherBuilder ) => {
122
- const otherFromOtherBuilder =
123
- otherBuilder . input ( 'other ' , kExampleInputDescriptor ) ;
141
+ const falseValueFromOtherBuilder =
142
+ otherBuilder . input ( 'falseValue ' , kExampleInputDescriptor ) ;
124
143
125
144
const condition = builder . input ( 'condition' , kExampleConditionDescriptor ) ;
126
- const input = builder . input ( 'input ' , kExampleInputDescriptor ) ;
145
+ const trueValue = builder . input ( 'trueValue ' , kExampleInputDescriptor ) ;
127
146
assert_throws_js (
128
- TypeError , ( ) => builder . where ( condition , input , otherFromOtherBuilder ) ) ;
129
- } , '[where] throw if other is from another builder' ) ;
147
+ TypeError ,
148
+ ( ) => builder . where ( condition , trueValue , falseValueFromOtherBuilder ) ) ;
149
+ } , '[where] throw if falseValue is from another builder' ) ;
0 commit comments