Skip to content

Commit dfbf252

Browse files
: slice: Slice::view (#113)
Summary: implements `Slice::view(&[usize]) -> Result<Slice>` with validation logic mirroring `torch.Tensor.view`. requires element count match and verifies row-major view layout is reachable in the base slice. conservative, safe reshape for canonical slices. Differential Revision: D75761780
1 parent 3f83778 commit dfbf252

File tree

1 file changed

+128
-0
lines changed

1 file changed

+128
-0
lines changed

ndslice/src/slice.rs

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ pub enum SliceError {
3030

3131
#[error("value {value} not in slice")]
3232
ValueNotInSlice { value: usize },
33+
34+
#[error("incompatible view: {reason}")]
35+
IncompatibleView { reason: String },
3336
}
3437

3538
/// Slice is a compact representation of indices into the flat
@@ -281,7 +284,92 @@ impl Slice {
281284
mapper,
282285
}
283286
}
287+
288+
/// Returns a new [`Slice`] with the given shape by reinterpreting
289+
/// the layout of this slice.
290+
///
291+
/// Constructs a new shape with standard row-major strides, using
292+
/// the same base offset. Returns an error if the reshaped view
293+
/// would access coordinates not valid in the original slice.
294+
///
295+
/// # Requirements
296+
///
297+
/// - This slice must be contiguous and have offset == 0.
298+
/// - The number of elements must match:
299+
/// `self.sizes().iter().product() == new_sizes.iter().product()`
300+
/// - Each flat offset in the proposed view must be valid in `self`.
301+
///
302+
/// # Errors
303+
///
304+
/// Returns [`SliceError::IncompatibleView`] if:
305+
/// - The element count differs
306+
/// - The base offset is nonzero
307+
/// - Any offset in the view is not reachable in the original slice
308+
///
309+
/// # Example
310+
///
311+
/// ```rust
312+
/// use ndslice::Slice;
313+
/// let base = Slice::new_row_major(&[2, 3, 4]);
314+
/// let reshaped = base.view(&[6, 4]).unwrap();
315+
/// ```
316+
pub fn view(&self, new_sizes: &[usize]) -> Result<Slice, SliceError> {
317+
let view_elems: usize = new_sizes.iter().product();
318+
let base_elems: usize = self.sizes().iter().product();
319+
320+
if view_elems != base_elems {
321+
return Err(SliceError::IncompatibleView {
322+
reason: format!(
323+
"element count mismatch: base has {}, view wants {}",
324+
base_elems, view_elems
325+
),
326+
});
327+
}
328+
329+
// Compute row-major strides
330+
let mut new_strides = vec![1; new_sizes.len()];
331+
for i in (0..new_sizes.len().saturating_sub(1)).rev() {
332+
new_strides[i] = new_strides[i + 1] * new_sizes[i + 1];
333+
}
334+
335+
let offset = self.offset();
336+
337+
// Validate that every address in the new view maps to a valid
338+
// coordinate in base
339+
let mut coord = vec![0; new_sizes.len()];
340+
for _ in 0..view_elems {
341+
// Compute offset of coord in view
342+
let offset_in_view = offset
343+
+ coord
344+
.iter()
345+
.zip(&new_strides)
346+
.map(|(i, s)| i * s)
347+
.sum::<usize>();
348+
349+
if self.coordinates(offset_in_view).is_err() {
350+
return Err(SliceError::IncompatibleView {
351+
reason: format!("offset {} not reachable in base", offset_in_view),
352+
});
353+
}
354+
355+
// Increment coordinate
356+
for j in (0..coord.len()).rev() {
357+
coord[j] += 1;
358+
if coord[j] < new_sizes[j] {
359+
break;
360+
}
361+
coord[j] = 0;
362+
}
363+
}
364+
365+
Ok(Slice {
366+
offset,
367+
sizes: new_sizes.to_vec(),
368+
strides: new_strides,
369+
})
370+
}
284371
}
372+
285373
impl std::fmt::Display for Slice {
286374
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
287375
write!(f, "{:?}", self)
@@ -557,4 +645,44 @@ mod tests {
557645
assert_eq!(s.sizes(), &[4, 4, 4]);
558646
assert_eq!(s.strides(), &[16, 4, 1]);
559647
}
648+
649+
#[test]
650+
fn test_slice_view_smoke() {
651+
use crate::Slice;
652+
653+
let base = Slice::new_row_major([2, 3, 4]);
654+
655+
// Reshape: compatible shape and layout
656+
let view = base.view(&[6, 4]).unwrap();
657+
assert_eq!(view.sizes(), &[6, 4]);
658+
assert_eq!(view.offset(), 0);
659+
assert_eq!(view.strides(), &[4, 1]);
660+
assert_eq!(
661+
view.location(&[5, 3]).unwrap(),
662+
base.location(&[1, 2, 3]).unwrap()
663+
);
664+
665+
// Reshape: identity (should succeed)
666+
let view = base.view(&[2, 3, 4]).unwrap();
667+
assert_eq!(view.sizes(), base.sizes());
668+
assert_eq!(view.strides(), base.strides());
669+
670+
// Reshape: incompatible shape (wrong element count)
671+
let err = base.view(&[5, 4]);
672+
assert!(err.is_err());
673+
674+
// Reshape: incompatible layout (simulate select)
675+
let selected = Slice::new(1, vec![2, 3], vec![6, 1]).unwrap(); // not offset=0
676+
let err = selected.view(&[3, 2]);
677+
assert!(err.is_err());
678+
679+
// Reshape: flat 1D view
680+
let flat = base.view(&[24]).unwrap();
681+
assert_eq!(flat.sizes(), &[24]);
682+
assert_eq!(flat.strides(), &[1]);
683+
assert_eq!(
684+
flat.location(&[23]).unwrap(),
685+
base.location(&[1, 2, 3]).unwrap()
686+
);
687+
}
560688
}

0 commit comments

Comments
 (0)