Skip to content

: add new module 'view' (#109) #112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 248 additions & 0 deletions ndslice/src/layout.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

use crate::slice::Slice;
use crate::slice::SliceError;
use crate::view::View;

mod sealed {
// Private trait — only types in this module can implement it
pub trait Sealed {}
}

/// A trait for memory layouts that map multidimensional coordinates
/// (in `ℕⁿ`) to linear memory offsets (`ℕ¹`) via an affine
/// transformation.
///
/// This abstraction describes how an `n`-dimensional shape is laid
/// out in memory using a strided affine map:
///
/// ```text
/// offset_of(x) = offset + dot(strides, x)
/// ```
///
/// This corresponds to an affine function `ℕⁿ → ℕ¹`, where `x` is a
/// coordinate in logical space, `strides` encodes layout, and
/// `offset` is the base address.
///
/// Implementors define how coordinates in `n`-dimensional space are
/// translated to flat memory locations, enabling support for
/// row-major, column-major, and custom layouts.
pub trait LayoutMap: sealed::Sealed {
/// The number of dimensions in the domain of the map.
fn rank(&self) -> usize;

/// The shape of the domain (number of elements per dimension).
fn sizes(&self) -> &[usize];

/// Maps a multidimensional coordinate to a linear memory offset.
fn offset_of(&self, coord: &[usize]) -> Result<usize, SliceError>;
}

/// A trait for memory layouts that support inverse mapping from
/// linear offsets (`ℕ¹`) back to multidimensional coordinates (in
/// `ℕⁿ`).
///
/// This defines the inverse of the affine layout transformation given
/// by [`LayoutMap::offset_of`], where an offset is mapped back to a
/// coordinate in logical space—if possible.
///
/// Not all layouts are invertible: aliasing, gaps, or padding may
/// prevent a one-to-one correspondence between coordinates and
/// offsets. However, standard layouts like contiguous row-major or
/// column-major typically do support inversion.
///
/// Implementors define how to reconstruct the coordinate `x ∈ ℕⁿ` for
/// a given offset `o ∈ ℕ`, or return `None` if no such coordinate
/// exists.
pub trait LayoutMapInverse: sealed::Sealed {
/// Computes the multidimensional coordinate for a given linear
/// offset, or returns `None` if the offset is out of bounds.
fn coord_of(&self, offset: usize) -> Option<Vec<usize>>;
}

/// Extension trait for applying shape transformations to layout-aware types.
///
/// This trait enables ergonomic, composable construction of [`View`]s
/// over types that implement [`LayoutMap`] — including [`Slice`] and
/// other layout-aware data structures. It supports zero-copy
/// reinterpretation of memory layout, subject to shape and stride
/// compatibility.
///
/// # Purpose
///
/// - Enables `.view(&[...]) -> Result<View>` syntax
/// - Defers layout validation until explicitly finalized
/// - Facilitates transformation chaining (e.g., `view → transpose`)
/// using composable `View` operations
///
/// # Requirements
///
/// Only [`LayoutMap`] is required, which allows forward coordinate-to-offset
/// mapping. Inversion (via [`LayoutMapInverse`]) is **not** required to
/// construct a `View`, only to finalize it into a [`Slice`] (e.g., via
/// `View::into_slice()`).
///
/// # Behavior
///
/// Calling `.view(&sizes)`:
/// - Computes row-major strides over `sizes`
/// - Computes the offset of `[0, 0, ..., 0]` in the base layout
/// - Constructs a [`View`] with new shape and strides
/// - **Does not** validate full layout compatibility — that is
/// deferred
///
/// # Example
///
/// ```rust
/// use ndslice::Slice
/// use ndslice::layout::LayoutTransformExt;
///
/// let base = Slice::new_row_major([2, 3]); // 2×3 row-major layout
/// let view = base.view(&[3, 2])?; // Valid reshape: 6 elements
/// ```
///
/// # Notes
///
/// - If `sizes` do not multiply to the same number of elements as the
/// base, an error is returned.
/// - If the new origin offset is not reachable in the base layout,
/// the view is rejected.
///
/// # See Also
///
/// - [`View`] — Lazy layout reinterpretation
/// - [`View::new`] — Raw constructor
/// - [`LayoutMap`] — Forward affine mapping trait
pub trait LayoutTransformExt {
fn view(&self, sizes: &[usize]) -> Result<View<'_>, SliceError>;
}

/// Blanket implementation of [`LayoutTransformExt`] for all types
/// that implement [`LayoutMap`].
///
/// This enables ergonomic access to `.view(...)` on any
/// layout-compatible type, such as [`Slice`], without modifying the
/// type itself.
///
/// The returned [`View`] is a lightweight logical reinterpretation of
/// the layout using row-major semantics. It does not yet validate
/// layout compatibility; that logic will be implemented as part of
/// `View::into_slice()` in a future revision.
///
/// # Example
///
/// ```rust
/// use ndslice::Slice
/// use ndslice::layout::LayoutTransformExt;
///
/// let base = Slice::new_row_major([2, 3]);
/// let view = base.view(&[3, 2])?;
/// ```
impl<T> LayoutTransformExt for T
where
T: LayoutMap,
{
fn view(&self, sizes: &[usize]) -> Result<View<'_>, SliceError> {
View::new(self, sizes.to_vec())
}
}

impl sealed::Sealed for Slice {}

impl LayoutMap for Slice {
fn rank(&self) -> usize {
self.sizes().len()
}

fn sizes(&self) -> &[usize] {
self.sizes()
}

fn offset_of(&self, coord: &[usize]) -> Result<usize, SliceError> {
if coord.len() != self.rank() {
return Err(SliceError::InvalidDims {
expected: self.rank(),
got: coord.len(),
});
}

// Dot product ∑ᵢ (strideᵢ × coordᵢ)
let linear_offset = self
.strides()
.iter()
.zip(coord)
.map(|(s, i)| s * i)
.sum::<usize>();

Ok(self.offset() + linear_offset)
}
}

impl LayoutMapInverse for Slice {
fn coord_of(&self, value: usize) -> Option<Vec<usize>> {
let mut pos = value.checked_sub(self.offset())?;
let mut result = vec![0; self.rank()];

let mut dims: Vec<_> = self
.strides()
.iter()
.zip(self.sizes().iter().enumerate())
.collect();

dims.sort_by_key(|&(stride, _)| *stride);

// Invert: offset = base + ∑ᵢ (strideᵢ × coordᵢ)
// Solve for coordᵢ by peeling off largest strides first:
// coordᵢ = ⌊pos / strideᵢ⌋
// pos -= coordᵢ × strideᵢ
// If any coordᵢ ≥ sizeᵢ or pos ≠ 0 at the end, the offset is
// invalid.
for &(stride, (i, &size)) in dims.iter().rev() {
let index = if size > 1 { pos / stride } else { 0 };
if index >= size {
return None;
}
result[i] = index;
pos -= index * stride;
}

(pos == 0).then_some(result)
}
}

impl<'a> sealed::Sealed for View<'a> {}

impl<'a> LayoutMap for View<'a> {
fn rank(&self) -> usize {
self.sizes.len()
}

fn sizes(&self) -> &[usize] {
&self.sizes
}

fn offset_of(&self, coord: &[usize]) -> Result<usize, SliceError> {
if coord.len() != self.sizes.len() {
return Err(SliceError::InvalidDims {
expected: self.sizes.len(),
got: coord.len(),
});
}

// Compute offset = base_offset + dot(strides, coord)
let offset = self
.strides
.iter()
.zip(coord.iter())
.map(|(s, i)| s * i)
.sum::<usize>();

Ok(self.offset + offset)
}
}
12 changes: 12 additions & 0 deletions ndslice/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ pub use slice::Slice;
pub use slice::SliceError;
pub use slice::SliceIterator;

/// Layout traits and types for mapping multidimensional coordinates
/// to linear memory.
pub mod layout;

/// View-based layout reinterpretation for `Slice`, similar to
/// `torch.Tensor.view`.
///
/// Provides the [`View`] type and [`Slice::view`] method, allowing
/// shape changes without copying when layouts are compatible. See
/// module docs in `view.rs` for details.
pub mod view;

/// Selection algebra for describing multidimensional mesh regions.
pub mod selection;

Expand Down
3 changes: 3 additions & 0 deletions ndslice/src/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ pub enum ShapeError {

#[error(transparent)]
SliceError(#[from] SliceError),

#[error("rank mismatch: expected {expected}, got {actual}")]
RankMismatch { expected: usize, actual: usize },
}

/// A shape is a [`Slice`] with labeled dimensions and a selection API.
Expand Down
47 changes: 9 additions & 38 deletions ndslice/src/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
use serde::Deserialize;
use serde::Serialize;

use crate::layout::LayoutMap;
use crate::layout::LayoutMapInverse;

/// The type of error for slice operations.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
Expand All @@ -30,6 +33,9 @@ pub enum SliceError {

#[error("value {value} not in slice")]
ValueNotInSlice { value: usize },

#[error("incompatible view: {reason}")]
IncompatibleView { reason: String },
}

/// Slice is a compact representation of indices into the flat
Expand Down Expand Up @@ -169,49 +175,14 @@ impl Slice {

/// Return the location of the provided coordinates.
pub fn location(&self, coord: &[usize]) -> Result<usize, SliceError> {
if coord.len() != self.sizes.len() {
return Err(SliceError::InvalidDims {
expected: self.sizes.len(),
got: coord.len(),
});
}
Ok(self.offset
+ coord
.iter()
.zip(&self.strides)
.map(|(pos, stride)| pos * stride)
.sum::<usize>())
self.offset_of(coord)
}

/// Return the coordinates of the provided value in the n-d space of this
/// Slice.
pub fn coordinates(&self, value: usize) -> Result<Vec<usize>, SliceError> {
let mut pos = value
.checked_sub(self.offset)
.ok_or(SliceError::ValueNotInSlice { value })?;
let mut result = vec![0; self.sizes.len()];
let mut sorted_info: Vec<_> = self
.strides
.iter()
.zip(self.sizes.iter().enumerate())
.collect();
sorted_info.sort_by_key(|&(stride, _)| *stride);
for &(stride, (i, &size)) in sorted_info.iter().rev() {
let (index, new_pos) = if size > 1 {
(pos / stride, pos % stride)
} else {
(0, pos)
};
if index >= size {
return Err(SliceError::ValueNotInSlice { value });
}
result[i] = index;
pos = new_pos;
}
if pos != 0 {
return Err(SliceError::ValueNotInSlice { value });
}
Ok(result)
self.coord_of(value)
.ok_or(SliceError::ValueNotInSlice { value })
}

/// Retrieve the underlying location of the provided slice index.
Expand Down
Loading