Skip to content

Commit 0d97472

Browse files
committed
Use rem_euclid for the stochastic rounding codec
1 parent 6ee6009 commit 0d97472

File tree

2 files changed

+57
-6
lines changed

2 files changed

+57
-6
lines changed

codecs/stochastic-rounding/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "numcodecs-stochastic-rounding"
3-
version = "0.1.0"
3+
version = "0.1.1"
44
edition = { workspace = true }
55
authors = { workspace = true }
66
repository = { workspace = true }

codecs/stochastic-rounding/src/lib.rs

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ where
221221
continue;
222222
}
223223

224-
let remainder = *x % precision.0;
224+
let remainder = x.rem_euclid(precision.0);
225225

226226
// compute the nearest multiples of precision based on the remainder
227227
// correct max 1 ULP rounding errors to ensure that the nearest
@@ -255,6 +255,10 @@ pub trait FloatExt: Float {
255255
/// Hash the binary representation of the floating point value
256256
fn hash_bits<H: Hasher>(self, hasher: &mut H);
257257

258+
/// Calculates the least nonnegative remainder of self (mod rhs).
259+
#[must_use]
260+
fn rem_euclid(self, rhs: Self) -> Self;
261+
258262
/// Returns the least number greater than `self`.
259263
#[must_use]
260264
fn next_up(self) -> Self;
@@ -271,6 +275,10 @@ impl FloatExt for f32 {
271275
hasher.write_u32(self.to_bits());
272276
}
273277

278+
fn rem_euclid(self, rhs: Self) -> Self {
279+
Self::rem_euclid(self, rhs)
280+
}
281+
274282
fn next_up(self) -> Self {
275283
Self::next_up(self)
276284
}
@@ -287,6 +295,10 @@ impl FloatExt for f64 {
287295
hasher.write_u64(self.to_bits());
288296
}
289297

298+
fn rem_euclid(self, rhs: Self) -> Self {
299+
Self::rem_euclid(self, rhs)
300+
}
301+
290302
fn next_up(self) -> Self {
291303
Self::next_up(self)
292304
}
@@ -336,28 +348,67 @@ mod tests {
336348
-f64::NAN,
337349
-f64::INFINITY,
338350
-42.0,
351+
-4.2,
339352
-0.0,
340353
0.0,
354+
4.2,
341355
42.0,
342356
f64::INFINITY,
343357
f64::NAN
344358
];
359+
let precision = 1.0;
345360

346-
let rounded = stochastic_rounding(data.view(), NonNegative(1.0), 42);
361+
let rounded = stochastic_rounding(data.view(), NonNegative(precision), 42);
347362

348363
for (d, r) in data.into_iter().zip(rounded) {
349-
assert!((r - d).abs() <= 1.0 || d.to_bits() == r.to_bits());
364+
assert!((r - d).abs() <= precision || d.to_bits() == r.to_bits());
350365
}
351366
}
352367

353368
#[test]
354369
fn round_rounding_errors() {
355370
let data = Array::from_iter(linspace(-100.0, 100.0, 3741));
371+
let precision = 0.1;
372+
373+
let rounded = stochastic_rounding(data.view(), NonNegative(precision), 42);
374+
375+
for (d, r) in data.into_iter().zip(rounded) {
376+
assert!((r - d).abs() <= precision);
377+
}
378+
}
379+
380+
#[test]
381+
fn test_rounding_bug() {
382+
let data = array![
383+
-1.23540_f32,
384+
-1.23539_f32,
385+
-1.23538_f32,
386+
-1.23537_f32,
387+
-1.23536_f32,
388+
-1.23535_f32,
389+
-1.23534_f32,
390+
-1.23533_f32,
391+
-1.23532_f32,
392+
-1.23531_f32,
393+
-1.23530_f32,
394+
1.23540_f32,
395+
1.23539_f32,
396+
1.23538_f32,
397+
1.23537_f32,
398+
1.23536_f32,
399+
1.23535_f32,
400+
1.23534_f32,
401+
1.23533_f32,
402+
1.23532_f32,
403+
1.23531_f32,
404+
1.23530_f32,
405+
];
406+
let precision = 0.00018_f32;
356407

357-
let rounded = stochastic_rounding(data.view(), NonNegative(0.1), 42);
408+
let rounded = stochastic_rounding(data.view(), NonNegative(precision), 42);
358409

359410
for (d, r) in data.into_iter().zip(rounded) {
360-
assert!((r - d).abs() <= 0.1);
411+
assert!((r - d).abs() <= precision);
361412
}
362413
}
363414
}

0 commit comments

Comments
 (0)