Skip to content

Commit 72b0d09

Browse files
authored
Merge pull request #1386 from akern40/triangular
Adds `triu` and `tril` methods directly to ArrayBase
2 parents e734ce8 + 84f0c80 commit 72b0d09

File tree

2 files changed

+295
-0
lines changed

2 files changed

+295
-0
lines changed

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1616,3 +1616,6 @@ pub(crate) fn is_aligned<T>(ptr: *const T) -> bool
16161616
{
16171617
(ptr as usize) % ::std::mem::align_of::<T>() == 0
16181618
}
1619+
1620+
// Triangular constructors
1621+
mod tri;

src/tri.rs

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
// Copyright 2014-2024 bluss and ndarray developers.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
use core::cmp::{max, min};
10+
11+
use num_traits::Zero;
12+
13+
use crate::{dimension::is_layout_f, Array, ArrayBase, Axis, Data, Dimension, IntoDimension, Zip};
14+
15+
impl<S, A, D> ArrayBase<S, D>
16+
where
17+
S: Data<Elem = A>,
18+
D: Dimension,
19+
A: Clone + Zero,
20+
D::Smaller: Copy,
21+
{
22+
/// Upper triangular of an array.
23+
///
24+
/// Return a copy of the array with elements below the *k*-th diagonal zeroed.
25+
/// For arrays with `ndim` exceeding 2, `triu` will apply to the final two axes.
26+
/// For 0D and 1D arrays, `triu` will return an unchanged clone.
27+
///
28+
/// See also [`ArrayBase::tril`]
29+
///
30+
/// ```
31+
/// use ndarray::array;
32+
///
33+
/// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
34+
/// let res = arr.triu(0);
35+
/// assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
36+
/// ```
37+
pub fn triu(&self, k: isize) -> Array<A, D>
38+
{
39+
if self.ndim() <= 1 {
40+
return self.to_owned();
41+
}
42+
match is_layout_f(&self.dim, &self.strides) {
43+
true => {
44+
let n = self.ndim();
45+
let mut x = self.view();
46+
x.swap_axes(n - 2, n - 1);
47+
let mut tril = x.tril(-k);
48+
tril.swap_axes(n - 2, n - 1);
49+
50+
tril
51+
}
52+
false => {
53+
let mut res = Array::zeros(self.raw_dim());
54+
Zip::indexed(self.rows())
55+
.and(res.rows_mut())
56+
.for_each(|i, src, mut dst| {
57+
let row_num = i.into_dimension().last_elem();
58+
let lower = max(row_num as isize + k, 0);
59+
dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..]));
60+
});
61+
62+
res
63+
}
64+
}
65+
}
66+
67+
/// Lower triangular of an array.
68+
///
69+
/// Return a copy of the array with elements above the *k*-th diagonal zeroed.
70+
/// For arrays with `ndim` exceeding 2, `tril` will apply to the final two axes.
71+
/// For 0D and 1D arrays, `tril` will return an unchanged clone.
72+
///
73+
/// See also [`ArrayBase::triu`]
74+
///
75+
/// ```
76+
/// use ndarray::array;
77+
///
78+
/// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
79+
/// let res = arr.tril(0);
80+
/// assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
81+
/// ```
82+
pub fn tril(&self, k: isize) -> Array<A, D>
83+
{
84+
if self.ndim() <= 1 {
85+
return self.to_owned();
86+
}
87+
match is_layout_f(&self.dim, &self.strides) {
88+
true => {
89+
let n = self.ndim();
90+
let mut x = self.view();
91+
x.swap_axes(n - 2, n - 1);
92+
let mut tril = x.triu(-k);
93+
tril.swap_axes(n - 2, n - 1);
94+
95+
tril
96+
}
97+
false => {
98+
let mut res = Array::zeros(self.raw_dim());
99+
let ncols = self.len_of(Axis(self.ndim() - 1)) as isize;
100+
Zip::indexed(self.rows())
101+
.and(res.rows_mut())
102+
.for_each(|i, src, mut dst| {
103+
let row_num = i.into_dimension().last_elem();
104+
let upper = min(row_num as isize + k, ncols) + 1;
105+
dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper]));
106+
});
107+
108+
res
109+
}
110+
}
111+
}
112+
}
113+
114+
#[cfg(test)]
115+
mod tests
116+
{
117+
use crate::{array, dimension, Array0, Array1, Array2, Array3, ShapeBuilder};
118+
use alloc::vec;
119+
120+
#[test]
121+
fn test_keep_order()
122+
{
123+
let x = Array2::<f64>::ones((3, 3).f());
124+
let res = x.triu(0);
125+
assert!(dimension::is_layout_f(&res.dim, &res.strides));
126+
127+
let res = x.tril(0);
128+
assert!(dimension::is_layout_f(&res.dim, &res.strides));
129+
}
130+
131+
#[test]
132+
fn test_0d()
133+
{
134+
let x = Array0::<f64>::ones(());
135+
let res = x.triu(0);
136+
assert_eq!(res, x);
137+
138+
let res = x.tril(0);
139+
assert_eq!(res, x);
140+
141+
let x = Array0::<f64>::ones(().f());
142+
let res = x.triu(0);
143+
assert_eq!(res, x);
144+
145+
let res = x.tril(0);
146+
assert_eq!(res, x);
147+
}
148+
149+
#[test]
150+
fn test_1d()
151+
{
152+
let x = array![1, 2, 3];
153+
let res = x.triu(0);
154+
assert_eq!(res, x);
155+
156+
let res = x.triu(0);
157+
assert_eq!(res, x);
158+
159+
let x = Array1::<f64>::ones(3.f());
160+
let res = x.triu(0);
161+
assert_eq!(res, x);
162+
163+
let res = x.triu(0);
164+
assert_eq!(res, x);
165+
}
166+
167+
#[test]
168+
fn test_2d()
169+
{
170+
let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
171+
172+
// Upper
173+
let res = x.triu(0);
174+
assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
175+
176+
// Lower
177+
let res = x.tril(0);
178+
assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
179+
180+
let x = Array2::from_shape_vec((3, 3).f(), vec![1, 4, 7, 2, 5, 8, 3, 6, 9]).unwrap();
181+
182+
// Upper
183+
let res = x.triu(0);
184+
assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
185+
186+
// Lower
187+
let res = x.tril(0);
188+
assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
189+
}
190+
191+
#[test]
192+
fn test_3d()
193+
{
194+
let x = array![
195+
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
196+
[[10, 11, 12], [13, 14, 15], [16, 17, 18]],
197+
[[19, 20, 21], [22, 23, 24], [25, 26, 27]]
198+
];
199+
200+
// Upper
201+
let res = x.triu(0);
202+
assert_eq!(
203+
res,
204+
array![
205+
[[1, 2, 3], [0, 5, 6], [0, 0, 9]],
206+
[[10, 11, 12], [0, 14, 15], [0, 0, 18]],
207+
[[19, 20, 21], [0, 23, 24], [0, 0, 27]]
208+
]
209+
);
210+
211+
// Lower
212+
let res = x.tril(0);
213+
assert_eq!(
214+
res,
215+
array![
216+
[[1, 0, 0], [4, 5, 0], [7, 8, 9]],
217+
[[10, 0, 0], [13, 14, 0], [16, 17, 18]],
218+
[[19, 0, 0], [22, 23, 0], [25, 26, 27]]
219+
]
220+
);
221+
222+
let x = Array3::from_shape_vec(
223+
(3, 3, 3).f(),
224+
vec![1, 10, 19, 4, 13, 22, 7, 16, 25, 2, 11, 20, 5, 14, 23, 8, 17, 26, 3, 12, 21, 6, 15, 24, 9, 18, 27],
225+
)
226+
.unwrap();
227+
228+
// Upper
229+
let res = x.triu(0);
230+
assert_eq!(
231+
res,
232+
array![
233+
[[1, 2, 3], [0, 5, 6], [0, 0, 9]],
234+
[[10, 11, 12], [0, 14, 15], [0, 0, 18]],
235+
[[19, 20, 21], [0, 23, 24], [0, 0, 27]]
236+
]
237+
);
238+
239+
// Lower
240+
let res = x.tril(0);
241+
assert_eq!(
242+
res,
243+
array![
244+
[[1, 0, 0], [4, 5, 0], [7, 8, 9]],
245+
[[10, 0, 0], [13, 14, 0], [16, 17, 18]],
246+
[[19, 0, 0], [22, 23, 0], [25, 26, 27]]
247+
]
248+
);
249+
}
250+
251+
#[test]
252+
fn test_off_axis()
253+
{
254+
let x = array![
255+
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
256+
[[10, 11, 12], [13, 14, 15], [16, 17, 18]],
257+
[[19, 20, 21], [22, 23, 24], [25, 26, 27]]
258+
];
259+
260+
let res = x.triu(1);
261+
assert_eq!(
262+
res,
263+
array![
264+
[[0, 2, 3], [0, 0, 6], [0, 0, 0]],
265+
[[0, 11, 12], [0, 0, 15], [0, 0, 0]],
266+
[[0, 20, 21], [0, 0, 24], [0, 0, 0]]
267+
]
268+
);
269+
270+
let res = x.triu(-1);
271+
assert_eq!(
272+
res,
273+
array![
274+
[[1, 2, 3], [4, 5, 6], [0, 8, 9]],
275+
[[10, 11, 12], [13, 14, 15], [0, 17, 18]],
276+
[[19, 20, 21], [22, 23, 24], [0, 26, 27]]
277+
]
278+
);
279+
}
280+
281+
#[test]
282+
fn test_odd_shape()
283+
{
284+
let x = array![[1, 2, 3], [4, 5, 6]];
285+
let res = x.triu(0);
286+
assert_eq!(res, array![[1, 2, 3], [0, 5, 6]]);
287+
288+
let x = array![[1, 2], [3, 4], [5, 6]];
289+
let res = x.triu(0);
290+
assert_eq!(res, array![[1, 2], [0, 4], [0, 0]]);
291+
}
292+
}

0 commit comments

Comments
 (0)