diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 78cdd2e1f..db1a26aba 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -1066,7 +1066,7 @@ def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Ten def _assert_tensor_has_no_elements_or_has_integers( tensor: Optional[torch.Tensor], tensor_name: str ) -> None: - if is_torchdynamo_compiling() or tensor is None: + if torch.compiler.is_compiling() or tensor is None: # Skipping the check tensor.numel() == 0 to not guard on pt2 symbolic shapes. # TODO(ivankobzarev): Use guard_size_oblivious to pass tensor.numel() == 0 once it is torch scriptable. return