diff --git a/checkpoint/orbax/checkpoint/_src/arrays/fragments.py b/checkpoint/orbax/checkpoint/_src/arrays/fragments.py index 478ae20e5..7ccbbadf3 100644 --- a/checkpoint/orbax/checkpoint/_src/arrays/fragments.py +++ b/checkpoint/orbax/checkpoint/_src/arrays/fragments.py @@ -178,18 +178,26 @@ def slice( stop = out.stop[:] = np.minimum(out.stop, slice_shape) if not (start < stop).all(): return None - if (value := self.value) is None: - return out - else: - value_fragment = Fragment( - np_index=np.stack([ - np.maximum(self.start, np_index[:, 0]), - np.minimum(self.stop, np_index[:, 1]), - np_index[:, 2], - ], axis=1) - ).offset_by(-self.start) - out_value = value[value_fragment.index or ...] - return dataclasses.replace(out, value=out_value) + return dataclasses.replace( + out, value=self.slice_of_value(np_index) + ) if self.value is not None else out + + def slice_of_value( + self, + new_np_idx: NpIndex, + ) -> np.ndarray: + """Returns a slice of `value`.""" + start = self.start + stop = self.stop + # This is just a convenient way to construct the required tuple of slices. + f = Fragment( + np_index=np.stack([ + np.maximum(start, new_np_idx[:, 0]), + np.minimum(stop, new_np_idx[:, 1]), + new_np_idx[:, 2], + ], axis=1) + ).offset_by(-start) + return self.value[f.index or ...] @dataclasses.dataclass(frozen=True)