Skip to content

Refactor stride_per_key_per_rank to support torch.Tensor #2872

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ def _kjt_concat(
lengths=torch.cat(length_list, dim=0),
stride=stride,
stride_per_key_per_rank=(
stride_per_key_per_rank if variable_stride_per_key else None
torch.tensor(stride_per_key_per_rank) if variable_stride_per_key else None
),
length_per_key=length_per_key if has_length_per_key else None,
inverse_indices=(
Expand Down Expand Up @@ -1096,7 +1096,7 @@ def _maybe_compute_stride_kjt(
stride: Optional[int],
lengths: Optional[torch.Tensor],
offsets: Optional[torch.Tensor],
stride_per_key_per_rank: Optional[List[List[int]]],
stride_per_key_per_rank: Union[Optional[torch.Tensor], Optional[List[List[int]]]],
) -> int:
if stride is None:
if len(keys) == 0:
Expand Down Expand Up @@ -1668,7 +1668,7 @@ def _maybe_compute_lengths_offset_per_key(

def _maybe_compute_stride_per_key(
stride_per_key: Optional[List[int]],
stride_per_key_per_rank: Optional[List[List[int]]],
stride_per_key_per_rank: Union[Optional[torch.Tensor], Optional[List[List[int]]]],
stride: Optional[int],
keys: List[str],
) -> Optional[List[int]]:
Expand All @@ -1684,7 +1684,7 @@ def _maybe_compute_stride_per_key(

def _maybe_compute_variable_stride_per_key(
variable_stride_per_key: Optional[bool],
stride_per_key_per_rank: Optional[List[List[int]]],
stride_per_key_per_rank: Union[Optional[torch.Tensor], Optional[List[List[int]]]],
) -> bool:
if variable_stride_per_key is not None:
return variable_stride_per_key
Expand Down Expand Up @@ -1766,7 +1766,9 @@ def __init__(
lengths: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
stride: Optional[int] = None,
stride_per_key_per_rank: Optional[List[List[int]]] = None,
stride_per_key_per_rank: Union[
Optional[torch.Tensor], Optional[List[List[int]]]
] = None,
# Below exposed to ensure torch.script-able
stride_per_key: Optional[List[int]] = None,
length_per_key: Optional[List[int]] = None,
Expand All @@ -1788,9 +1790,9 @@ def __init__(
self._lengths: Optional[torch.Tensor] = lengths
self._offsets: Optional[torch.Tensor] = offsets
self._stride: Optional[int] = stride
self._stride_per_key_per_rank: Optional[List[List[int]]] = (
stride_per_key_per_rank
)
self._stride_per_key_per_rank: Union[
Optional[torch.Tensor], Optional[List[List[int]]]
] = stride_per_key_per_rank
self._stride_per_key: Optional[List[int]] = stride_per_key
self._length_per_key: Optional[List[int]] = length_per_key
self._offset_per_key: Optional[List[int]] = offset_per_key
Expand Down Expand Up @@ -1827,7 +1829,9 @@ def from_offsets_sync(
offsets: torch.Tensor,
weights: Optional[torch.Tensor] = None,
stride: Optional[int] = None,
stride_per_key_per_rank: Optional[List[List[int]]] = None,
stride_per_key_per_rank: Union[
Optional[torch.Tensor], Optional[List[List[int]]]
] = None,
inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
) -> "KeyedJaggedTensor":
"""
Expand All @@ -1840,7 +1844,7 @@ def from_offsets_sync(
weights (Optional[torch.Tensor]): if the values have weights. Tensor with the
same shape as values.
stride (Optional[int]): number of examples per batch.
stride_per_key_per_rank (Optional[List[List[int]]]): batch size
stride_per_key_per_rank (Union[Optional[torch.Tensor], Optional[List[List[int]]]]): batch size
(number of examples) per key per rank, with the outer list representing the
keys and the inner list representing the values.
inverse_indices (Optional[Tuple[List[str], torch.Tensor]]): inverse indices to
Expand All @@ -1867,7 +1871,9 @@ def from_lengths_sync(
lengths: torch.Tensor,
weights: Optional[torch.Tensor] = None,
stride: Optional[int] = None,
stride_per_key_per_rank: Optional[List[List[int]]] = None,
stride_per_key_per_rank: Union[
Optional[torch.Tensor], Optional[List[List[int]]]
] = None,
inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
) -> "KeyedJaggedTensor":
"""
Expand All @@ -1881,7 +1887,7 @@ def from_lengths_sync(
weights (Optional[torch.Tensor]): if the values have weights. Tensor with the
same shape as values.
stride (Optional[int]): number of examples per batch.
stride_per_key_per_rank (Optional[List[List[int]]]): batch size
stride_per_key_per_rank (Union[Optional[torch.Tensor], Optional[List[List[int]]]]): batch size
(number of examples) per key per rank, with the outer list representing the
keys and the inner list representing the values.
inverse_indices (Optional[Tuple[List[str], torch.Tensor]]): inverse indices to
Expand Down Expand Up @@ -2193,8 +2199,15 @@ def stride_per_key_per_rank(self) -> List[List[int]]:
Returns:
List[List[int]]: stride per key per rank of the KeyedJaggedTensor.
"""
if self._stride_per_key_per_rank is None:
return []

stride_per_key_per_rank = self._stride_per_key_per_rank
return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
return (
stride_per_key_per_rank.tolist()
if isinstance(stride_per_key_per_rank, torch.Tensor)
else stride_per_key_per_rank
)

def variable_stride_per_key(self) -> bool:
"""
Expand Down
Loading