Skip to content

Commit

Permalink
Merge pull request #1315 from yotamofek/owned-view-iter
Browse files Browse the repository at this point in the history
Allow creating matrix iter with an owned view
  • Loading branch information
sebcrozet authored Nov 12, 2023
2 parents 83eccc6 + 1195ead commit a91e3b0
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 84 deletions.
24 changes: 24 additions & 0 deletions src/base/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,18 @@ impl<'a, T: Scalar, R: Dim, C: Dim, S: RawStorage<T, R, C>> IntoIterator
}
}

impl<'a, T: Scalar, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoIterator
for Matrix<T, R, C, ViewStorage<'a, T, R, C, RStride, CStride>>
{
type Item = &'a T;
type IntoIter = MatrixIter<'a, T, R, C, ViewStorage<'a, T, R, C, RStride, CStride>>;

#[inline]
fn into_iter(self) -> Self::IntoIter {
MatrixIter::new_owned(self.data)
}
}

impl<'a, T: Scalar, R: Dim, C: Dim, S: RawStorageMut<T, R, C>> IntoIterator
for &'a mut Matrix<T, R, C, S>
{
Expand All @@ -110,6 +122,18 @@ impl<'a, T: Scalar, R: Dim, C: Dim, S: RawStorageMut<T, R, C>> IntoIterator
}
}

impl<'a, T: Scalar, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoIterator
for Matrix<T, R, C, ViewStorageMut<'a, T, R, C, RStride, CStride>>
{
type Item = &'a mut T;
type IntoIter = MatrixIterMut<'a, T, R, C, ViewStorageMut<'a, T, R, C, RStride, CStride>>;

#[inline]
fn into_iter(self) -> Self::IntoIter {
MatrixIterMut::new_owned_mut(self.data)
}
}

impl<T: Scalar, const D: usize> From<[T; D]> for SVector<T, D> {
#[inline]
fn from(arr: [T; D]) -> Self {
Expand Down
152 changes: 121 additions & 31 deletions src/base/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,29 @@ use std::mem;

use crate::base::dimension::{Dim, U1};
use crate::base::storage::{RawStorage, RawStorageMut};
use crate::base::{Matrix, MatrixView, MatrixViewMut, Scalar};
use crate::base::{Matrix, MatrixView, MatrixViewMut, Scalar, ViewStorage, ViewStorageMut};

#[derive(Clone, Debug)]
struct RawIter<Ptr, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> {
ptr: Ptr,
inner_ptr: Ptr,
inner_end: Ptr,
size: usize,
strides: (RStride, CStride),
_phantoms: PhantomData<(fn() -> T, R, C)>,
}

macro_rules! iterator {
(struct $Name:ident for $Storage:ident.$ptr: ident -> $Ptr:ty, $Ref:ty, $SRef: ty, $($derives:ident),* $(,)?) => {
/// An iterator through a dense matrix with arbitrary strides matrix.
#[derive($($derives),*)]
pub struct $Name<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> {
ptr: $Ptr,
inner_ptr: $Ptr,
inner_end: $Ptr,
size: usize, // We can't use an end pointer here because a stride might be zero.
strides: (S::RStride, S::CStride),
_phantoms: PhantomData<($Ref, R, C, S)>,
}

// TODO: we need to specialize for the case where the matrix storage is owned (in which
// case the iterator is trivial because it does not have any stride).
impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> $Name<'a, T, R, C, S> {
impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
RawIter<$Ptr, T, R, C, RStride, CStride>
{
/// Creates a new iterator for the given matrix storage.
pub fn new(storage: $SRef) -> $Name<'a, T, R, C, S> {
fn new<'a, S: $Storage<T, R, C, RStride = RStride, CStride = CStride>>(
storage: $SRef,
) -> Self {
let shape = storage.shape();
let strides = storage.strides();
let inner_offset = shape.0.value() * strides.0.value();
Expand All @@ -55,7 +58,7 @@ macro_rules! iterator {
unsafe { ptr.add(inner_offset) }
};

$Name {
RawIter {
ptr,
inner_ptr: ptr,
inner_end,
Expand All @@ -66,11 +69,13 @@ macro_rules! iterator {
}
}

impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> Iterator for $Name<'a, T, R, C, S> {
type Item = $Ref;
impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> Iterator
for RawIter<$Ptr, T, R, C, RStride, CStride>
{
type Item = $Ptr;

#[inline]
fn next(&mut self) -> Option<$Ref> {
fn next(&mut self) -> Option<Self::Item> {
unsafe {
if self.size == 0 {
None
Expand Down Expand Up @@ -102,10 +107,7 @@ macro_rules! iterator {
self.ptr = self.ptr.add(stride);
}

// We want either `& *last` or `&mut *last` here, depending
// on the mutability of `$Ref`.
#[allow(clippy::transmute_ptr_to_ref)]
Some(mem::transmute(old))
Some(old)
}
}
}
Expand All @@ -121,11 +123,11 @@ macro_rules! iterator {
}
}

impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> DoubleEndedIterator
for $Name<'a, T, R, C, S>
impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> DoubleEndedIterator
for RawIter<$Ptr, T, R, C, RStride, CStride>
{
#[inline]
fn next_back(&mut self) -> Option<$Ref> {
fn next_back(&mut self) -> Option<Self::Item> {
unsafe {
if self.size == 0 {
None
Expand All @@ -152,24 +154,88 @@ macro_rules! iterator {
.ptr
.add((outer_remaining * outer_stride + inner_remaining * inner_stride));

// We want either `& *last` or `&mut *last` here, depending
// on the mutability of `$Ref`.
#[allow(clippy::transmute_ptr_to_ref)]
Some(mem::transmute(last))
Some(last)
}
}
}
}

impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> ExactSizeIterator
for $Name<'a, T, R, C, S>
impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> ExactSizeIterator
for RawIter<$Ptr, T, R, C, RStride, CStride>
{
#[inline]
fn len(&self) -> usize {
self.size
}
}

impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> FusedIterator
for RawIter<$Ptr, T, R, C, RStride, CStride>
{
}

/// An iterator through a dense matrix with arbitrary strides matrix.
#[derive($($derives),*)]
pub struct $Name<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> {
inner: RawIter<$Ptr, T, R, C, S::RStride, S::CStride>,
_marker: PhantomData<$Ref>,
}

impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> $Name<'a, T, R, C, S> {
/// Creates a new iterator for the given matrix storage.
pub fn new(storage: $SRef) -> Self {
Self {
inner: RawIter::<$Ptr, T, R, C, S::RStride, S::CStride>::new(storage),
_marker: PhantomData,
}
}
}

impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> Iterator for $Name<'a, T, R, C, S> {
type Item = $Ref;

#[inline(always)]
fn next(&mut self) -> Option<Self::Item> {
// We want either `& *last` or `&mut *last` here, depending
// on the mutability of `$Ref`.
#[allow(clippy::transmute_ptr_to_ref)]
self.inner.next().map(|ptr| unsafe { mem::transmute(ptr) })
}

#[inline(always)]
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}

#[inline(always)]
fn count(self) -> usize {
self.inner.count()
}
}

impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> DoubleEndedIterator
for $Name<'a, T, R, C, S>
{
#[inline(always)]
fn next_back(&mut self) -> Option<Self::Item> {
// We want either `& *last` or `&mut *last` here, depending
// on the mutability of `$Ref`.
#[allow(clippy::transmute_ptr_to_ref)]
self.inner
.next_back()
.map(|ptr| unsafe { mem::transmute(ptr) })
}
}

impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> ExactSizeIterator
for $Name<'a, T, R, C, S>
{
#[inline(always)]
fn len(&self) -> usize {
self.inner.len()
}
}

impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> FusedIterator
for $Name<'a, T, R, C, S>
{
Expand All @@ -180,6 +246,30 @@ macro_rules! iterator {
iterator!(struct MatrixIter for RawStorage.ptr -> *const T, &'a T, &'a S, Clone, Debug);
iterator!(struct MatrixIterMut for RawStorageMut.ptr_mut -> *mut T, &'a mut T, &'a mut S, Debug);

impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
MatrixIter<'a, T, R, C, ViewStorage<'a, T, R, C, RStride, CStride>>
{
/// Creates a new iterator for the given matrix storage view.
pub fn new_owned(storage: ViewStorage<'a, T, R, C, RStride, CStride>) -> Self {
Self {
inner: RawIter::<*const T, T, R, C, RStride, CStride>::new(&storage),
_marker: PhantomData,
}
}
}

impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
MatrixIterMut<'a, T, R, C, ViewStorageMut<'a, T, R, C, RStride, CStride>>
{
/// Creates a new iterator for the given matrix storage view.
pub fn new_owned_mut(mut storage: ViewStorageMut<'a, T, R, C, RStride, CStride>) -> Self {
Self {
inner: RawIter::<*mut T, T, R, C, RStride, CStride>::new(&mut storage),
_marker: PhantomData,
}
}
}

/*
*
* Row iterators.
Expand Down
Loading

0 comments on commit a91e3b0

Please sign in to comment.