@@ -58,4 +58,60 @@ where S: Data<Elem = A>
5858 self . dim = self . dim . remove_axis ( axis) ;
5959 self . strides = self . strides . remove_axis ( axis) ;
6060 }
61+
62+ /// Remove axes of length 1 and return the modified array.
63+ ///
64+ /// If the array has more the one dimension, the result array will always
65+ /// have at least one dimension, even if it has a length of 1.
66+ ///
67+ /// ```
68+ /// use ndarray::{arr1, arr2, arr3};
69+ ///
70+ /// let a = arr3(&[[[1, 2, 3]], [[4, 5, 6]]]).into_dyn();
71+ /// assert_eq!(a.shape(), &[2, 1, 3]);
72+ /// let b = a.squeeze();
73+ /// assert_eq!(b, arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn());
74+ /// assert_eq!(b.shape(), &[2, 3]);
75+ ///
76+ /// let c = arr2(&[[1]]).into_dyn();
77+ /// assert_eq!(c.shape(), &[1, 1]);
78+ /// let d = c.squeeze();
79+ /// assert_eq!(d, arr1(&[1]).into_dyn());
80+ /// assert_eq!(d.shape(), &[1]);
81+ /// ```
82+ #[ track_caller]
83+ pub fn squeeze ( self ) -> Self
84+ {
85+ let mut out = self ;
86+ for axis in ( 0 ..out. shape ( ) . len ( ) ) . rev ( ) {
87+ if out. shape ( ) [ axis] == 1 && out. shape ( ) . len ( ) > 1 {
88+ out = out. remove_axis ( Axis ( axis) ) ;
89+ }
90+ }
91+ out
92+ }
93+ }
94+
95+ #[ cfg( test) ]
96+ mod tests
97+ {
98+ use crate :: { arr1, arr2, arr3} ;
99+
100+ #[ test]
101+ fn test_squeeze ( )
102+ {
103+ let a = arr3 ( & [ [ [ 1 , 2 , 3 ] ] , [ [ 4 , 5 , 6 ] ] ] ) . into_dyn ( ) ;
104+ assert_eq ! ( a. shape( ) , & [ 2 , 1 , 3 ] ) ;
105+
106+ let b = a. squeeze ( ) ;
107+ assert_eq ! ( b, arr2( & [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] ] ) . into_dyn( ) ) ;
108+ assert_eq ! ( b. shape( ) , & [ 2 , 3 ] ) ;
109+
110+ let c = arr2 ( & [ [ 1 ] ] ) . into_dyn ( ) ;
111+ assert_eq ! ( c. shape( ) , & [ 1 , 1 ] ) ;
112+
113+ let d = c. squeeze ( ) ;
114+ assert_eq ! ( d, arr1( & [ 1 ] ) . into_dyn( ) ) ;
115+ assert_eq ! ( d. shape( ) , & [ 1 ] ) ;
116+ }
61117}
0 commit comments