diff --git a/ndslice/src/layout.rs b/ndslice/src/layout.rs
new file mode 100644
index 00000000..38b8a0eb
--- /dev/null
+++ b/ndslice/src/layout.rs
@@ -0,0 +1,259 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+use crate::slice::Slice;
+use crate::slice::SliceError;
+use crate::view::View;
+
+mod sealed {
+    // Private trait — only types in this module can implement it
+    pub trait Sealed {}
+}
+
+/// A trait for memory layouts that map multidimensional coordinates
+/// (in `ℕⁿ`) to linear memory offsets (`ℕ¹`) via an affine
+/// transformation.
+///
+/// This abstraction describes how an `n`-dimensional shape is laid
+/// out in memory using a strided affine map:
+///
+/// ```text
+/// offset_of(x) = offset + dot(strides, x)
+/// ```
+///
+/// This corresponds to an affine function `ℕⁿ → ℕ¹`, where `x` is a
+/// coordinate in logical space, `strides` encodes layout, and
+/// `offset` is the base address.
+///
+/// Implementors define how coordinates in `n`-dimensional space are
+/// translated to flat memory locations, enabling support for
+/// row-major, column-major, and custom layouts.
+pub trait LayoutMap: sealed::Sealed {
+    /// The number of dimensions in the domain of the map.
+    fn rank(&self) -> usize;
+
+    /// The shape of the domain (number of elements per dimension).
+    fn sizes(&self) -> &[usize];
+
+    /// Maps a multidimensional coordinate to a linear memory offset.
+    fn offset_of(&self, coord: &[usize]) -> Result<usize, SliceError>;
+}
+
+/// A trait for memory layouts that support inverse mapping from
+/// linear offsets (`ℕ¹`) back to multidimensional coordinates (in
+/// `ℕⁿ`).
+///
+/// This defines the inverse of the affine layout transformation given
+/// by [`LayoutMap::offset_of`], where an offset is mapped back to a
+/// coordinate in logical space—if possible.
+///
+/// Not all layouts are invertible: aliasing, gaps, or padding may
+/// prevent a one-to-one correspondence between coordinates and
+/// offsets. However, standard layouts like contiguous row-major or
+/// column-major typically do support inversion.
+///
+/// Implementors define how to reconstruct the coordinate `x ∈ ℕⁿ` for
+/// a given offset `o ∈ ℕ`, or return `None` if no such coordinate
+/// exists.
+pub trait LayoutMapInverse: sealed::Sealed {
+    /// Computes the multidimensional coordinate for a given linear
+    /// offset, or returns `None` if the offset is out of bounds.
+    fn coord_of(&self, offset: usize) -> Option<Vec<usize>>;
+}
+
+/// Extension trait for applying shape transformations to layout-aware types.
+///
+/// This trait enables ergonomic, composable construction of [`View`]s
+/// over types that implement [`LayoutMap`] — including [`Slice`] and
+/// other layout-aware data structures. It supports zero-copy
+/// reinterpretation of memory layout, subject to shape and stride
+/// compatibility.
+///
+/// # Purpose
+///
+/// - Enables `.view(&[...]) -> Result<View, SliceError>` syntax
+/// - Defers layout validation until explicitly finalized
+/// - Facilitates transformation chaining (e.g., `view → transpose`)
+///   using composable `View` operations
+///
+/// # Requirements
+///
+/// Only [`LayoutMap`] is required, which allows forward coordinate-to-offset
+/// mapping. Inversion (via [`LayoutMapInverse`]) is **not** required to
+/// construct a `View`, only to finalize it into a [`Slice`] (e.g., via
+/// `View::into_slice()`).
+///
+/// # Behavior
+///
+/// Calling `.view(&sizes)`:
+/// - Computes row-major strides over `sizes`
+/// - Computes the offset of `[0, 0, ..., 0]` in the base layout
+/// - Constructs a [`View`] with new shape and strides
+/// - **Does not** validate full layout compatibility — that is
+///   deferred
+///
+/// # Example
+///
+/// ```rust
+/// use ndslice::Slice;
+/// use ndslice::layout::LayoutTransformExt;
+///
+/// let base = Slice::new_row_major([2, 3]); // 2×3 row-major layout
+/// let view = base.view(&[3, 2])?; // Valid reshape: 6 elements
+/// # Ok::<(), ndslice::SliceError>(())
+/// ```
+///
+/// # Notes
+///
+/// - If `sizes` do not multiply to the same number of elements as the
+///   base, an error is returned.
+/// - If the new origin offset is not reachable in the base layout,
+///   the view is rejected.
+///
+/// # See Also
+///
+/// - [`View`] — Lazy layout reinterpretation
+/// - [`View::new`] — Raw constructor
+/// - [`LayoutMap`] — Forward affine mapping trait
+pub trait LayoutTransformExt {
+    /// Reinterprets this layout as a [`View`] with the given `sizes`.
+    ///
+    /// # Errors
+    ///
+    /// Returns the [`SliceError`] reported by the implementation if
+    /// the view cannot be constructed.
+    fn view(&self, sizes: &[usize]) -> Result<View<'_>, SliceError>;
+}
+
+/// Blanket implementation of [`LayoutTransformExt`] for all types
+/// that implement [`LayoutMap`].
+///
+/// This enables ergonomic access to `.view(...)` on any
+/// layout-compatible type, such as [`Slice`], without modifying the
+/// type itself.
+///
+/// The returned [`View`] is a lightweight logical reinterpretation of
+/// the layout using row-major semantics. It does not yet validate
+/// layout compatibility; that logic will be implemented as part of
+/// `View::into_slice()` in a future revision.
+///
+/// # Example
+///
+/// ```rust
+/// use ndslice::Slice;
+/// use ndslice::layout::LayoutTransformExt;
+///
+/// let base = Slice::new_row_major([2, 3]);
+/// let view = base.view(&[3, 2])?;
+/// # Ok::<(), ndslice::SliceError>(())
+/// ```
+impl<T> LayoutTransformExt for T
+where
+    T: LayoutMap,
+{
+    fn view(&self, sizes: &[usize]) -> Result<View<'_>, SliceError> {
+        View::new(self, sizes.to_vec())
+    }
+}
+
+impl sealed::Sealed for Slice {}
+
+impl LayoutMap for Slice {
+    fn rank(&self) -> usize {
+        self.sizes().len()
+    }
+
+    fn sizes(&self) -> &[usize] {
+        self.sizes()
+    }
+
+    fn offset_of(&self, coord: &[usize]) -> Result<usize, SliceError> {
+        if coord.len() != self.rank() {
+            return Err(SliceError::InvalidDims {
+                expected: self.rank(),
+                got: coord.len(),
+            });
+        }
+
+        // Dot product ∑ᵢ (strideᵢ × coordᵢ)
+        let linear_offset = self
+            .strides()
+            .iter()
+            .zip(coord)
+            .map(|(s, i)| s * i)
+            .sum::<usize>();
+
+        Ok(self.offset() + linear_offset)
+    }
+}
+
+impl LayoutMapInverse for Slice {
+    fn coord_of(&self, value: usize) -> Option<Vec<usize>> {
+        let mut pos = value.checked_sub(self.offset())?;
+        let mut result = vec![0; self.rank()];
+
+        let mut dims: Vec<_> = self
+            .strides()
+            .iter()
+            .zip(self.sizes().iter().enumerate())
+            .collect();
+
+        dims.sort_by_key(|&(stride, _)| *stride);
+
+        // Invert: offset = base + ∑ᵢ (strideᵢ × coordᵢ)
+        // Solve for coordᵢ by peeling off largest strides first:
+        //   coordᵢ = ⌊pos / strideᵢ⌋
+        //   pos -= coordᵢ × strideᵢ
+        // If any coordᵢ ≥ sizeᵢ or pos ≠ 0 at the end, the offset is
+        // invalid.
+        for &(stride, (i, &size)) in dims.iter().rev() {
+            // NOTE(review): `pos / stride` panics if a dimension with
+            // size > 1 has stride 0 — confirm `Slice` construction
+            // rules that out.
+            let index = if size > 1 { pos / stride } else { 0 };
+            if index >= size {
+                return None;
+            }
+            result[i] = index;
+            pos -= index * stride;
+        }
+
+        (pos == 0).then_some(result)
+    }
+}
+
+impl<'a> sealed::Sealed for View<'a> {}
+
+impl<'a> LayoutMap for View<'a> {
+    fn rank(&self) -> usize {
+        self.sizes.len()
+    }
+
+    fn sizes(&self) -> &[usize] {
+        &self.sizes
+    }
+
+    fn offset_of(&self, coord: &[usize]) -> Result<usize, SliceError> {
+        if coord.len() != self.sizes.len() {
+            return Err(SliceError::InvalidDims {
+                expected: self.sizes.len(),
+                got: coord.len(),
+            });
+        }
+
+        // Compute offset = base_offset + dot(strides, coord)
+        let offset = self
+            .strides
+            .iter()
+            .zip(coord.iter())
+            .map(|(s, i)| s * i)
+            .sum::<usize>();
+
+        Ok(self.offset + offset)
+    }
+}
diff --git a/ndslice/src/lib.rs b/ndslice/src/lib.rs
index 58ce786e..605fd81d 100644
--- a/ndslice/src/lib.rs
+++ b/ndslice/src/lib.rs
@@ -26,6 +26,18 @@ pub use slice::Slice;
 pub use slice::SliceError;
 pub use slice::SliceIterator;
 
+/// Layout traits and types for mapping multidimensional coordinates
+/// to linear memory.
+pub mod layout;
+
+/// View-based layout reinterpretation for `Slice`, similar to
+/// `torch.Tensor.view`.
+///
+/// Provides the [`View`] type and [`Slice::view`] method, allowing
+/// shape changes without copying when layouts are compatible. See
+/// module docs in `view.rs` for details.
+pub mod view;
+
 /// Selection algebra for describing multidimensional mesh regions.
 pub mod selection;
 
diff --git a/ndslice/src/shape.rs b/ndslice/src/shape.rs
index 20f2fed6..5c046484 100644
--- a/ndslice/src/shape.rs
+++ b/ndslice/src/shape.rs
@@ -52,6 +52,9 @@ pub enum ShapeError {
 
     #[error(transparent)]
     SliceError(#[from] SliceError),
+
+    #[error("rank mismatch: expected {expected}, got {actual}")]
+    RankMismatch { expected: usize, actual: usize },
 }
 
 /// A shape is a [`Slice`] with labeled dimensions and a selection API.
diff --git a/ndslice/src/slice.rs b/ndslice/src/slice.rs
index c64b81b8..35af7d0b 100644
--- a/ndslice/src/slice.rs
+++ b/ndslice/src/slice.rs
@@ -9,6 +9,9 @@
 use serde::Deserialize;
 use serde::Serialize;
 
+use crate::layout::LayoutMap;
+use crate::layout::LayoutMapInverse;
+
 /// The type of error for slice operations.
 #[derive(Debug, thiserror::Error)]
 #[non_exhaustive]
@@ -30,6 +33,9 @@ pub enum SliceError {
 
     #[error("value {value} not in slice")]
     ValueNotInSlice { value: usize },
+
+    #[error("incompatible view: {reason}")]
+    IncompatibleView { reason: String },
 }
 
 /// Slice is a compact representation of indices into the flat
@@ -169,49 +175,14 @@ impl Slice {
 
     /// Return the location of the provided coordinates.
     pub fn location(&self, coord: &[usize]) -> Result<usize, SliceError> {
-        if coord.len() != self.sizes.len() {
-            return Err(SliceError::InvalidDims {
-                expected: self.sizes.len(),
-                got: coord.len(),
-            });
-        }
-        Ok(self.offset
-            + coord
-                .iter()
-                .zip(&self.strides)
-                .map(|(pos, stride)| pos * stride)
-                .sum::<usize>())
+        self.offset_of(coord)
     }
 
     /// Return the coordinates of the provided value in the n-d space of this
     /// Slice.
     pub fn coordinates(&self, value: usize) -> Result<Vec<usize>, SliceError> {
-        let mut pos = value
-            .checked_sub(self.offset)
-            .ok_or(SliceError::ValueNotInSlice { value })?;
-        let mut result = vec![0; self.sizes.len()];
-        let mut sorted_info: Vec<_> = self
-            .strides
-            .iter()
-            .zip(self.sizes.iter().enumerate())
-            .collect();
-        sorted_info.sort_by_key(|&(stride, _)| *stride);
-        for &(stride, (i, &size)) in sorted_info.iter().rev() {
-            let (index, new_pos) = if size > 1 {
-                (pos / stride, pos % stride)
-            } else {
-                (0, pos)
-            };
-            if index >= size {
-                return Err(SliceError::ValueNotInSlice { value });
-            }
-            result[i] = index;
-            pos = new_pos;
-        }
-        if pos != 0 {
-            return Err(SliceError::ValueNotInSlice { value });
-        }
-        Ok(result)
+        self.coord_of(value)
+            .ok_or(SliceError::ValueNotInSlice { value })
     }
 
     /// Retrieve the underlying location of the provided slice index.
diff --git a/ndslice/src/view.rs b/ndslice/src/view.rs
new file mode 100644
index 00000000..4e75fd9e
--- /dev/null
+++ b/ndslice/src/view.rs
@@ -0,0 +1,256 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+//! View planning and design
+//!
+//! This module implements `Slice::view(...)` with semantics analogous
+//! to `torch.Tensor.view(...)`. The goal is to reinterpret the memory
+//! layout of an existing `Slice` without copying, assuming the new
+//! shape is element-count compatible and layout-compatible with the
+//! base slice.
+//!
+//! # Objective
+//!
+//! Provide an API like:
+//!
+//! ```ignore
+//! let v: View<'_> = slice.view(&[2, 3, 4])?;
+//! let reshaped: Slice = v.into_slice()?;
+//! ```
+//!
+//! ## Requirements
+//!
+//! - The new shape must have the same number of elements as the base.
+//!   ✅
+//! - The new shape must be layout-compatible — i.e. its logical
+//!   traversal must match the base slice's physical memory order. ⏳
+//!   (partially enforced)
+//! - No memory copying or reallocation is performed. ✅
+//! - The returned `View` supports further transformations (e.g.
+//!   `transpose`, etc.) before being finalized as a `Slice`. ⏳
+//!
+//! ## Stride Compatibility (Contiguity-like Condition)
+//!
+//! To match PyTorch semantics, the layout of the proposed view must
+//! be compatible with the base slice's strides. This requires that
+//! the dimensions of the view either:
+//!
+//! - Correspond directly to dimensions of the base, or
+//! - Span across multiple base dimensions whose strides satisfy the
+//!   contiguity-like condition:
+//!
+//! ```text
+//! ∀ i = d .. d+k−1:
+//!     stride[i] == stride[i+1] * size[i+1]
+//! ```
+//!
+//! This condition ensures the new view can be projected onto the base
+//! memory without ambiguity or aliasing. If this fails, `view()` must
+//! return an error. We currently do not support automatic copying to
+//! make incompatible views possible.
+//!
+//! # Design
+//!
+//! We introduce a `View<'a>` type that holds:
+//!
+//! ```ignore
+//! pub struct View<'a> {
+//!     base: &'a dyn LayoutMap,
+//!     offset: usize,
+//!     sizes: Vec<usize>,
+//!     strides: Vec<usize>,
+//! }
+//! ```
+//!
+//! The `View` acts as a deferred layout reinterpretation over a base
+//! `LayoutMap`. It allows chaining and validation without eagerly
+//! materializing a new `Slice`.
+//!
+//! ## Responsibilities
+//!
+//! - ✅ `View::new(base, sizes)`:
+//!   - Computes offset from base
+//!   - Computes row-major strides for sizes
+//!   - Validates that total element count matches base
+//!   - Constructs a `View` (without validating layout yet)
+//!
+//! - ⏳ `View::validate_layout()`:
+//!   - Iterates over all coordinates in the view
+//!   - Maps each coordinate to a linear offset via the view
+//!   - Uses `base.coord_of(offset)` to check round-trip validity
+//!   - Ensures all addresses produced by the view are reachable in
+//!     the base
+//!
+//! - ⏳ `View::into_slice()`:
+//!   - Not yet implemented
+//!   - Will run `validate_layout()`
+//!   - Will return a new `Slice { offset, sizes, strides }`
+//!
+//! ## Slice API
+//!
+//! ✅
+//! ```ignore
+//! impl Slice {
+//!     pub fn view(&self, new_shape: &[usize]) -> Result<View<'_>, SliceError> {
+//!         View::new(self, new_shape.to_vec())
+//!     }
+//! }
+//! ```
+//!
+//! ## Error Handling
+//!
+//! ✅ View construction and finalization may fail if the shape or
+//! layout is incompatible with the base slice. To report these
+//! failures, we added:
+//!
+//! ```ignore
+//! #[derive(Error, Debug)]
+//! pub enum SliceError {
+//!     ...
+//!     #[error("incompatible view: {reason}")]
+//!     IncompatibleView { reason: String },
+//! }
+//! ```
+//!
+//! Used to signal:
+//! - Mismatched element count ✅
+//! - Unreachable origin offset ✅
+//! - Layout incompatibility ⏳
+//!
+//! # Summary
+//!
+//! This design mirrors PyTorch’s `Tensor.view()` behavior while
+//! embracing Rust’s type system and layout abstraction. The `View`
+//! type is a pure, cheap, composable transformation that defers
+//! validation and finalization until explicitly requested.
+//!
+//! ## Row-Major to Column-Major Conversion
+//!
+//! As a proof of concept for the generality of `View`, we implement a
+//! transformation that reinterprets a row-major `Slice` as
+//! column-major — and vice versa — via `View::transpose(...)`, by
+//! modifying strides while preserving sizes and offset.
+//!
+//! For example:
+//!
+//! ```ignore
+//! // Original row-major Slice:
+//! sizes:   [3, 4]
+//! strides: [4, 1]
+//!
+//! // View as column-major (via transpose):
+//! sizes:   [4, 3]
+//! strides: [1, 4]
+//! ```
+use crate::layout::LayoutMap;
+use crate::slice::SliceError;
+
+/// A deferred, zero-copy reinterpretation of a base layout under a
+/// new shape with row-major strides.
+pub struct View<'a> {
+    /// The layout this view reinterprets.
+    pub base: &'a dyn LayoutMap,
+    /// Offset of the base layout's origin coordinate.
+    pub offset: usize,
+    /// Extent of each view dimension.
+    pub sizes: Vec<usize>,
+    /// Row-major strides over `sizes`, as computed by [`View::new`].
+    pub strides: Vec<usize>,
+}
+
+impl<'a> View<'a> {
+    /// Constructs a new `View` over an existing layout with the given
+    /// shape.
+    ///
+    /// This function creates a logical reinterpretation of the `base`
+    /// layout using a new shape and standard row-major strides. The
+    /// result is a lightweight, composable transformation that does
+    /// not copy or reallocate memory.
+    ///
+    /// # Invariants established by this constructor:
+    ///
+    /// - The new shape's element count matches that of the base:
+    ///   `∏(sizes) == ∏(base.sizes())`
+    ///
+    /// - The new view starts at coordinate `[0, 0, ..., 0]`, which is
+    ///   guaranteed to map to a valid flat offset in the base layout
+    ///   (`offset_of(origin)`).
+    ///
+    /// - The `strides` field defines a valid **row-major layout**
+    ///   over `sizes`, such that `offset + dot(strides, coord)`
+    ///   computes the flat offset of any coordinate in the view.
+    ///
+    /// # What is NOT yet validated:
+    ///
+    /// - This function does **not** check that the entire layout
+    ///   defined by the view's shape and strides is compatible with
+    ///   the base layout.
+    ///
+    /// - In particular, it does **not** verify that all coordinates
+    ///   in the view map to addresses that are valid (reachable and
+    ///   non-aliased) in the base.
+    ///
+    /// - It also does not validate that the stride pattern is
+    ///   layout-compatible with the base's physical memory ordering
+    ///   (e.g., contiguity conditions).
+    ///
+    /// # Why validation is deferred:
+    ///
+    /// This design mirrors PyTorch’s `Tensor.view()` behavior:
+    /// - `View::new(...)` is cheap and composable
+    /// - Full validation is performed only at finalization (e.g., in
+    ///   `into_slice()`), after all transformations (e.g.,
+    ///   `.view().transpose().reshape()`) are complete.
+    ///
+    /// This enables flexible and efficient layout manipulation
+    /// without prematurely committing to a particular representation.
+    ///
+    /// # Errors
+    ///
+    /// Returns `SliceError::IncompatibleView` if:
+    /// - The total number of elements in `sizes` does not match the
+    ///   base
+    /// - The origin offset `[0, 0, ..., 0]` is not reachable in the
+    ///   base
+    pub fn new(base: &'a dyn LayoutMap, sizes: Vec<usize>) -> Result<Self, SliceError> {
+        // Compute standard row-major strides.
+        let mut strides = vec![1; sizes.len()];
+        for i in (0..sizes.len().saturating_sub(1)).rev() {
+            strides[i] = strides[i + 1] * sizes[i + 1];
+        }
+
+        let view_elem_count = sizes.iter().product::<usize>();
+        let base_elem_count = base.sizes().iter().product::<usize>();
+        if view_elem_count != base_elem_count {
+            return Err(SliceError::IncompatibleView {
+                reason: format!(
+                    "element count mismatch: base has {}, view wants {}",
+                    base_elem_count, view_elem_count
+                ),
+            });
+        }
+
+        // The origin must be expressed in the base's coordinate space:
+        // the `LayoutMap` impls reject coordinates whose length differs
+        // from the layout's rank, so a rank-changing view (e.g.
+        // `[2, 3, 4]` → `[6, 4]`) needs `base.rank()` zeros.
+        let origin = vec![0; base.rank()];
+        let offset = base
+            .offset_of(&origin)
+            .map_err(|_e| SliceError::IncompatibleView {
+                reason: "could not compute origin offset in base layout".into(),
+            })?;
+
+        Ok(Self {
+            base,
+            offset,
+            sizes,
+            strides,
+        })
+    }
+}