@@ -189,47 +189,31 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
189189 if index == 0 : # condition
190190 tensor_constraints = [
191191 cp .Dtype .In (lambda deps : [torch .bool ]),
192- cp .Value .Ge (lambda deps , dtype , struct : - ( 2 ** 4 ) ),
193- cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
192+ cp .Value .Ge (lambda deps , dtype , struct : 0 ),
193+ cp .Value .Le (lambda deps , dtype , struct : 1 ),
194194 cp .Rank .Ge (lambda deps : 1 ),
195195 cp .Size .Ge (lambda deps , r , d : 1 ),
196196 max_size_constraint ,
197197 ]
198198 elif index == 1 : # input tensor(a)
199199 tensor_constraints = [
200- cp .Dtype .In (
201- lambda deps : [
202- torch .int8 ,
203- torch .int16 ,
204- torch .uint8 ,
205- torch .uint16 ,
206- torch .int32 ,
207- torch .float32 ,
208- ]
209- ),
200+ cp .Dtype .In (lambda deps : [torch .float32 ]),
210201 cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
211202 cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
212203 cp .Rank .Ge (lambda deps : 1 ),
213204 cp .Size .Ge (lambda deps , r , d : 1 ),
205+ cp .Size .In (lambda deps , r , d : fn .broadcast_with (deps [0 ].shape , r , d )),
214206 max_size_constraint ,
215207 ]
216208 else : # input tensor(b)
217209 tensor_constraints = [
218- cp .Dtype .In (
219- lambda deps : [
220- torch .int8 ,
221- torch .int16 ,
222- torch .uint8 ,
223- torch .uint16 ,
224- torch .int32 ,
225- torch .float32 ,
226- ]
227- ),
210+ cp .Dtype .In (lambda deps : [torch .float32 ]),
228211 cp .Dtype .Eq (lambda deps : deps [1 ].dtype ),
229212 cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
230213 cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
231214 cp .Rank .Ge (lambda deps : 1 ),
232215 cp .Size .Ge (lambda deps , r , d : 1 ),
216+ cp .Size .In (lambda deps , r , d : fn .broadcast_with (fn .broadcasted_shape (deps [0 ].shape , deps [1 ].shape ), r , d )),
233217 max_size_constraint ,
234218 ]
235219 case "embedding.default" :
0 commit comments