Skip to content

Commit 3f318f3

Browse files
: slice: Slice::view
Differential Revision: D75761780
1 parent 3f83778 commit 3f318f3

File tree

1 file changed

+108
-0
lines changed

1 file changed

+108
-0
lines changed

ndslice/src/slice.rs

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ pub enum SliceError {
3030

3131
#[error("value {value} not in slice")]
3232
ValueNotInSlice { value: usize },
33+
34+
#[error("incompatible view: {reason}")]
35+
IncompatibleView { reason: String },
3336
}
3437

3538
/// Slice is a compact representation of indices into the flat
@@ -281,7 +284,112 @@ impl Slice {
281284
mapper,
282285
}
283286
}
287+
288+
/// Returns a new [`Slice`] representing a view of this slice with
289+
/// the given shape.
290+
///
291+
/// This operation attempts to logically reinterpret the layout of
292+
/// the current slice using the provided `new_sizes`, without
293+
/// copying or reallocating memory.
294+
///
295+
/// # Requirements
296+
///
297+
/// - The total number of elements must match:
298+
/// `self.sizes().iter().product() == new_sizes.iter().product()`
299+
///
300+
/// - The physical layout of the current slice must be compatible
301+
/// with a standard row-major traversal of `new_sizes`. In other
302+
/// words, the linear memory addresses visited by iterating over
303+
/// the view must all be valid and reachable in the base slice.
304+
///
305+
/// If these conditions are not met, an error is returned.
306+
///
307+
/// # Errors
308+
///
309+
/// Returns [`SliceError::IncompatibleView`] if:
310+
/// - The element count does not match
311+
/// - The view's origin is unreachable
312+
/// - Any address implied by the view is not accessible in the
313+
/// base layout
314+
///
315+
/// # Examples
316+
///
317+
/// ```rust
318+
/// let base = Slice::new_row_major(&[2, 3, 4]);
319+
/// let reshaped = base.view(&[4, 3, 2])?;
320+
/// ```
321+
///
322+
/// # Notes
323+
///
324+
/// This is a conservative operation: it only succeeds when the
325+
/// proposed view is fully compatible with the base layout,
326+
/// ensuring safety without memory aliasing.
327+
///
328+
/// Use this when reinterpreting shape while preserving layout
329+
/// guarantees.
330+
pub fn view(&self, new_sizes: &[usize]) -> Result<Slice, SliceError> {
331+
let view_elems: usize = new_sizes.iter().product();
332+
let base_elems: usize = self.sizes().iter().product();
333+
334+
if view_elems != base_elems {
335+
return Err(SliceError::IncompatibleView {
336+
reason: format!(
337+
"element count mismatch: base has {}, view wants {}",
338+
base_elems, view_elems
339+
),
340+
});
341+
}
342+
343+
// Compute row-major strides
344+
let mut new_strides = vec![1; new_sizes.len()];
345+
for i in (0..new_sizes.len().saturating_sub(1)).rev() {
346+
new_strides[i] = new_strides[i + 1] * new_sizes[i + 1];
347+
}
348+
349+
// Attempt to find base offset from origin
350+
let origin = vec![0; new_sizes.len()];
351+
let offset = self
352+
.location(&origin)
353+
.map_err(|_| SliceError::IncompatibleView {
354+
reason: "could not compute origin offset in base".into(),
355+
})?;
356+
357+
// Validate that every address in the new view maps to a valid
358+
// coordinate in base
359+
let mut coord = vec![0; new_sizes.len()];
360+
for _ in 0..view_elems {
361+
// Compute offset of coord in view
362+
let offset_in_view = offset
363+
+ coord
364+
.iter()
365+
.zip(&new_strides)
366+
.map(|(i, s)| i * s)
367+
.sum::<usize>();
368+
369+
if self.coordinates(offset_in_view).is_ok() {
370+
return Err(SliceError::IncompatibleView {
371+
reason: format!("offset {} not reachable in base", offset_in_view),
372+
});
373+
}
374+
375+
// Increment coordinate
376+
for j in (0..coord.len()).rev() {
377+
coord[j] += 1;
378+
if coord[j] < new_sizes[j] {
379+
break;
380+
}
381+
coord[j] = 0;
382+
}
383+
}
384+
385+
Ok(Slice {
386+
offset,
387+
sizes: new_sizes.to_vec(),
388+
strides: new_strides,
389+
})
390+
}
284391
}
392+
285393
impl std::fmt::Display for Slice {
286394
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
287395
write!(f, "{:?}", self)

0 commit comments

Comments
 (0)