@@ -221,7 +221,7 @@ where
221221 continue ;
222222 }
223223
224- let remainder = * x % precision. 0 ;
224+ let remainder = x . rem_euclid ( precision. 0 ) ;
225225
226226 // compute the nearest multiples of precision based on the remainder
227227 // correct max 1 ULP rounding errors to ensure that the nearest
@@ -255,6 +255,10 @@ pub trait FloatExt: Float {
255255 /// Hash the binary representation of the floating point value
256256 fn hash_bits < H : Hasher > ( self , hasher : & mut H ) ;
257257
258+ /// Calculates the least nonnegative remainder of self (mod rhs).
259+ #[ must_use]
260+ fn rem_euclid ( self , rhs : Self ) -> Self ;
261+
258262 /// Returns the least number greater than `self`.
259263 #[ must_use]
260264 fn next_up ( self ) -> Self ;
@@ -271,6 +275,10 @@ impl FloatExt for f32 {
271275 hasher. write_u32 ( self . to_bits ( ) ) ;
272276 }
273277
278+ fn rem_euclid ( self , rhs : Self ) -> Self {
279+ Self :: rem_euclid ( self , rhs)
280+ }
281+
274282 fn next_up ( self ) -> Self {
275283 Self :: next_up ( self )
276284 }
@@ -287,6 +295,10 @@ impl FloatExt for f64 {
287295 hasher. write_u64 ( self . to_bits ( ) ) ;
288296 }
289297
298+ fn rem_euclid ( self , rhs : Self ) -> Self {
299+ Self :: rem_euclid ( self , rhs)
300+ }
301+
290302 fn next_up ( self ) -> Self {
291303 Self :: next_up ( self )
292304 }
@@ -336,28 +348,67 @@ mod tests {
336348 -f64 :: NAN ,
337349 -f64 :: INFINITY ,
338350 -42.0 ,
351+ -4.2 ,
339352 -0.0 ,
340353 0.0 ,
354+ 4.2 ,
341355 42.0 ,
342356 f64 :: INFINITY ,
343357 f64 :: NAN
344358 ] ;
359+ let precision = 1.0 ;
345360
346- let rounded = stochastic_rounding ( data. view ( ) , NonNegative ( 1.0 ) , 42 ) ;
361+ let rounded = stochastic_rounding ( data. view ( ) , NonNegative ( precision ) , 42 ) ;
347362
348363 for ( d, r) in data. into_iter ( ) . zip ( rounded) {
349- assert ! ( ( r - d) . abs( ) <= 1.0 || d. to_bits( ) == r. to_bits( ) ) ;
364+ assert ! ( ( r - d) . abs( ) <= precision || d. to_bits( ) == r. to_bits( ) ) ;
350365 }
351366 }
352367
353368 #[ test]
354369 fn round_rounding_errors ( ) {
355370 let data = Array :: from_iter ( linspace ( -100.0 , 100.0 , 3741 ) ) ;
371+ let precision = 0.1 ;
372+
373+ let rounded = stochastic_rounding ( data. view ( ) , NonNegative ( precision) , 42 ) ;
374+
375+ for ( d, r) in data. into_iter ( ) . zip ( rounded) {
376+ assert ! ( ( r - d) . abs( ) <= precision) ;
377+ }
378+ }
379+
380+ #[ test]
381+ fn test_rounding_bug ( ) {
382+ let data = array ! [
383+ -1.23540_f32 ,
384+ -1.23539_f32 ,
385+ -1.23538_f32 ,
386+ -1.23537_f32 ,
387+ -1.23536_f32 ,
388+ -1.23535_f32 ,
389+ -1.23534_f32 ,
390+ -1.23533_f32 ,
391+ -1.23532_f32 ,
392+ -1.23531_f32 ,
393+ -1.23530_f32 ,
394+ 1.23540_f32 ,
395+ 1.23539_f32 ,
396+ 1.23538_f32 ,
397+ 1.23537_f32 ,
398+ 1.23536_f32 ,
399+ 1.23535_f32 ,
400+ 1.23534_f32 ,
401+ 1.23533_f32 ,
402+ 1.23532_f32 ,
403+ 1.23531_f32 ,
404+ 1.23530_f32 ,
405+ ] ;
406+ let precision = 0.00018_f32 ;
356407
357- let rounded = stochastic_rounding ( data. view ( ) , NonNegative ( 0.1 ) , 42 ) ;
408+ let rounded = stochastic_rounding ( data. view ( ) , NonNegative ( precision ) , 42 ) ;
358409
359410 for ( d, r) in data. into_iter ( ) . zip ( rounded) {
360- assert ! ( ( r - d) . abs( ) <= 0.1 ) ;
411+ assert ! ( ( r - d) . abs( ) <= precision ) ;
361412 }
362413 }
363414}
0 commit comments