@@ -285,18 +285,17 @@ def _emit_list(self, val: List[_Argument], val_type: _SchemaType) -> EValue:
285
285
286
286
NOTE: When symbool and symfloat are supported bool and float lists will be stored boxed.
287
287
"""
288
- elem_type = type (val_type )
289
288
290
- if elem_type == torch .BoolType :
289
+ if isinstance ( val_type , torch .BoolType ) :
291
290
return EValue (BoolList (typing .cast (List [bool ], val )))
292
291
293
- if elem_type == torch .IntType :
292
+ if isinstance ( val_type , torch .IntType ) :
294
293
return self ._emit_int_list (val )
295
294
296
- if elem_type == torch .FloatType :
295
+ if isinstance ( val_type , torch .FloatType ) :
297
296
return EValue (DoubleList (typing .cast (List [float ], val )))
298
297
299
- if elem_type == torch .TensorType :
298
+ if isinstance ( val_type , torch .TensorType ) :
300
299
values = []
301
300
for v in val :
302
301
assert isinstance (v , _AbstractValue )
@@ -308,10 +307,10 @@ def _emit_list(self, val: List[_Argument], val_type: _SchemaType) -> EValue:
308
307
values .append (v .id )
309
308
return EValue (TensorList (values ))
310
309
311
- if elem_type == torch .OptionalType :
310
+ if isinstance ( val_type , torch .OptionalType ) :
312
311
# refine further
313
- actual_type = typing . cast ( torch . OptionalType , val_type ) .getElementType ()
314
- if type (actual_type ) == torch .TensorType :
312
+ actual_type = val_type .getElementType ()
313
+ if isinstance (actual_type , torch .TensorType ) :
315
314
vals = []
316
315
for v in val :
317
316
if v is None :
@@ -437,9 +436,9 @@ def _constant_to_evalue( # noqa: C901
437
436
val_type = torch .ListType (
438
437
self ._get_list_tuple_jit_type (val ) # pyre-ignore
439
438
)
440
- if type (val_type ) == torch .OptionalType :
439
+ if isinstance (val_type , torch .OptionalType ) :
441
440
val_type = val_type .getElementType ()
442
- assert type (val_type ) == torch .ListType
441
+ assert isinstance (val_type , torch .ListType )
443
442
return self ._emit_list (
444
443
typing .cast (List [_Argument ], val ),
445
444
typing .cast (_SchemaType , val_type .getElementType ()),
0 commit comments