Skip to content

Commit

Permalink
decimal: improve Decimal.Pow accuracy
Browse files Browse the repository at this point in the history
  • Loading branch information
eapenkin authored Aug 23, 2023
1 parent ec5b253 commit f458f1e
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 31 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## [0.1.8] - 2023-08-23

### Changed

- Improved accuracy of `Decimal.Pow`.

## [0.1.7] - 2023-08-20

### Changed
Expand Down
46 changes: 46 additions & 0 deletions coefficient.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,41 @@ var spow10 = [...]*sint{
newSintFromPow10(62),
newSintFromPow10(63),
newSintFromPow10(64),
newSintFromPow10(65),
newSintFromPow10(66),
newSintFromPow10(67),
newSintFromPow10(68),
newSintFromPow10(69),
newSintFromPow10(70),
newSintFromPow10(71),
newSintFromPow10(72),
newSintFromPow10(73),
newSintFromPow10(74),
newSintFromPow10(75),
newSintFromPow10(76),
newSintFromPow10(77),
newSintFromPow10(78),
newSintFromPow10(79),
newSintFromPow10(80),
newSintFromPow10(81),
newSintFromPow10(82),
newSintFromPow10(83),
newSintFromPow10(84),
newSintFromPow10(85),
newSintFromPow10(86),
newSintFromPow10(87),
newSintFromPow10(88),
newSintFromPow10(89),
newSintFromPow10(90),
newSintFromPow10(91),
newSintFromPow10(92),
newSintFromPow10(93),
newSintFromPow10(94),
newSintFromPow10(95),
newSintFromPow10(96),
newSintFromPow10(97),
newSintFromPow10(98),
newSintFromPow10(99),
}

// newSintFromFint converts fint to *sint.
Expand Down Expand Up @@ -394,6 +429,17 @@ func (z *sint) fsa(shift int, y byte) {
z.add(z, newSintFromFint(fint(y)))
}

// rshDown (Right Shift) calculates x / 10^shift and rounds result towards 0.
func (z *sint) rshDown(x *sint, shift int) {
var y *sint
if shift < len(spow10) {
y = spow10[shift]
} else {
y = newSintFromPow10(shift)
}
z.quo(x, y)
}

// rshHalfEven (Right Shift) calculates x / 10^shift and
// rounds result using "half to even" rule.
func (z *sint) rshHalfEven(x *sint, shift int) {
Expand Down
77 changes: 77 additions & 0 deletions coefficient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,82 @@ func TestFint_hasPrec(t *testing.T) {
}
}

func TestSint_rshDown(t *testing.T) {
cases := []struct {
z string
shift int
want string
}{
// Rounding
{"1", 0, "1"},
{"20", 1, "2"},
{"18", 1, "1"},
{"15", 1, "1"},
{"12", 1, "1"},
{"10", 1, "1"},
{"8", 1, "0"},
{"5", 1, "0"},
{"2", 1, "0"},
{"9999999999999999999", 19, "0"},
{"9999999999999999999", 100, "0"},

// Large shifts
{"0", 17, "0"},
{"0", 18, "0"},
{"0", 19, "0"},
{"0", 20, "0"},
{"0", 21, "0"},
{"1", 17, "0"},
{"1", 18, "0"},
{"1", 19, "0"},
{"1", 20, "0"},
{"1", 21, "0"},
{"5000000000000000000", 17, "50"},
{"5000000000000000000", 18, "5"},
{"5000000000000000000", 19, "0"},
{"5000000000000000000", 20, "0"},
{"5000000000000000000", 21, "0"},
{"5000000000000000001", 17, "50"},
{"5000000000000000001", 18, "5"},
{"5000000000000000001", 19, "0"},
{"5000000000000000001", 20, "0"},
{"5000000000000000001", 21, "0"},
{"9999999999999999999", 17, "99"},
{"9999999999999999999", 18, "9"},
{"9999999999999999999", 19, "0"},
{"9999999999999999999", 20, "0"},
{"9999999999999999999", 21, "0"},
{"10000000000000000000", 17, "100"},
{"10000000000000000000", 18, "10"},
{"10000000000000000000", 19, "1"},
{"10000000000000000000", 20, "0"},
{"10000000000000000000", 21, "0"},
{"14999999999999999999", 17, "149"},
{"14999999999999999999", 18, "14"},
{"14999999999999999999", 19, "1"},
{"14999999999999999999", 20, "0"},
{"14999999999999999999", 21, "0"},
{"15000000000000000000", 17, "150"},
{"15000000000000000000", 18, "15"},
{"15000000000000000000", 19, "1"},
{"15000000000000000000", 20, "0"},
{"15000000000000000000", 21, "0"},
{"18446744073709551615", 17, "184"},
{"18446744073709551615", 18, "18"},
{"18446744073709551615", 19, "1"},
{"18446744073709551615", 20, "0"},
{"18446744073709551615", 21, "0"},
}
for _, tt := range cases {
got := mustParseSint(tt.z)
got.rshDown(got, tt.shift)
want := mustParseSint(tt.want)
if got.cmp(want) != 0 {
t.Errorf("%v.rshDown(%v) = %v, want %v", tt.z, tt.shift, got, want)
}
}
}

func TestSint_rshHalfEven(t *testing.T) {
cases := []struct {
z string
Expand All @@ -496,6 +572,7 @@ func TestSint_rshHalfEven(t *testing.T) {
{"5", 1, "0"},
{"2", 1, "0"},
{"9999999999999999999", 19, "1"},
{"9999999999999999999", 100, "0"},

// Large shifts
{"0", 17, "0"},
Expand Down
130 changes: 105 additions & 25 deletions decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -1073,47 +1073,127 @@ func (d Decimal) mulSint(e Decimal, minScale int) (Decimal, error) {
// Pow returns the (possibly rounded) decimal raised to the given power.
//
// Pow returns an error if the integer part of the power has more than [MaxPrec] digits.
func (d Decimal) Pow(exp int) (Decimal, error) {
return d.PowExact(exp, 0)
func (d Decimal) Pow(power int) (Decimal, error) {
return d.PowExact(power, 0)
}

// PowExact is similar to [Decimal.Pow], but it allows you to specify the number of digits
// after the decimal point that should be considered significant.
// If any of the significant digits are lost during rounding, the method will return an error.
// This method is useful for financial calculations where the scale should be
// equal to or greater than the currency's scale.
func (d Decimal) PowExact(exp, scale int) (Decimal, error) {
e, err := d.powLoop(exp, scale)
if err != nil {
return Decimal{}, fmt.Errorf("%v^%v: %w", d, exp, err)
func (d Decimal) PowExact(power, scale int) (Decimal, error) {
if scale < 0 || scale > MaxScale {
return Decimal{}, fmt.Errorf("%v^%v: %w", d, power, errScaleRange)
}
// Trailing zeros (Workaround)
e = e.Trim(scale)
return e, nil
}

func (d Decimal) powLoop(exp, scale int) (Decimal, error) {
// Special case: power of 0
if exp == 0 {
return One, nil
// Special case: negative power
if power < 0 {
e, err := d.PowExact(-power, scale)
if err != nil {
return Decimal{}, fmt.Errorf("%v^(%v): %w", d, power, err)
}
e, err = One.QuoExact(e, scale)
if err != nil {
return Decimal{}, fmt.Errorf("%v^(%v): %w", d, power, err)
}
return e, nil
}

// General case
e, err := d.powLoop(exp/2, scale)
e, err := d.powFint(power, scale)
if err != nil {
return Decimal{}, err
}
e, err = e.MulExact(e, scale)
if err != nil {
return Decimal{}, err
e, err = d.powSint(power, scale)
if err != nil {
return Decimal{}, fmt.Errorf("%v^%v: %w", d, power, err)
}
}
if exp%2 == 0 {
return e, nil

return e, nil
}

func (d Decimal) powFint(power, minScale int) (Decimal, error) {
dneg, dcoef, dscale := d.IsNeg(), d.coef, d.Scale()
eneg, ecoef, escale := false, fint(1), 0

for power > 0 {
if power%2 == 1 {
power = power - 1

// Coefficient (Multiplication)
var ok bool
ecoef, ok = ecoef.mul(dcoef)
if !ok {
return Decimal{}, errDecimalOverflow
}

// Sign
eneg = eneg != dneg

// Scale
escale = escale + dscale
}
if power > 0 {
power = power / 2

// Coefficient (Squaring)
var ok bool
dcoef, ok = dcoef.mul(dcoef)
if !ok {
return Decimal{}, errDecimalOverflow
}

// Sign
dneg = false

// Scale
dscale = dscale * 2
}
}
if exp > 0 {
return e.MulExact(d, scale)
return newDecimalFromFint(eneg, ecoef, escale, minScale)
}

func (d Decimal) powSint(power, minScale int) (Decimal, error) {
dneg, dcoef, dscale := d.IsNeg(), newSintFromFint(d.coef), d.Scale()
eneg, ecoef, escale := false, newSintFromFint(1), 0

for power > 0 {
if power%2 == 1 {
power = power - 1

// Coefficient (Multiplication)
ecoef.mul(dcoef, ecoef)

// Sign
eneg = eneg != dneg

// Scale and truncation
escale = escale + dscale
if escale > 2*MaxScale {
shift := escale - 2*MaxScale
ecoef.rshDown(ecoef, shift)
escale = escale - shift
}
}
if power > 0 {
power = power / 2

// Coefficient (Squaring)
dcoef.mul(dcoef, dcoef)

// Sign
dneg = false

// Scale and truncation
dscale = dscale * 2
if dscale > 2*MaxScale {
shift := dscale - 2*MaxScale
dcoef.rshDown(dcoef, shift)
dscale = dscale - shift
}
}
}
return e.QuoExact(d, scale)
return newDecimalFromSint(eneg, ecoef, escale, minScale)
}

// Add returns the (possibly rounded) sum of decimals d and e.
Expand Down
15 changes: 9 additions & 6 deletions decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1866,11 +1866,11 @@ func TestDecimal_Pow(t *testing.T) {
{"0.5", 9, "0.001953125"},

// Interest accrual
{"1.1", 60, "304.4816395414180996"}, // no error
{"1.01", 600, "391.5833969993197775"}, // should be 391.5833969993197743, error in the last two digits
{"1.001", 6000, "402.2211245663552073"}, // should be 402.2211245663552923, error in the last three digits
{"1.0001", 60000, "403.3077910727185768"}, // should be 403.3077910727185433, error in the last three digits
{"1.00001", 600000, "403.4166908911752717"}, // should be 403.4166908911542153, error in the last six digits
{"1.1", 60, "304.4816395414180996"},
{"1.01", 600, "391.5833969993197743"},
{"1.001", 6000, "402.2211245663552923"},
{"1.0001", 60000, "403.3077910727185433"},
{"1.00001", 600000, "403.4166908911542153"},
}
for _, tt := range tests {
d := MustParse(tt.d)
Expand All @@ -1892,8 +1892,11 @@ func TestDecimal_Pow(t *testing.T) {
power, scale int
}{
"overflow 1": {"2", 64, 0},
"overflow 2": {"10", 19, 0},
"overflow 2": {"2", -64, 0},
"overflow 3": {"10", 19, 0},
"zero": {"0.1", -20, 0},
"scale 1": {"1", 1, MaxScale},
"scale 2": {"1", 1, -1},
}
for _, tt := range tests {
d := MustParse(tt.d)
Expand Down

0 comments on commit f458f1e

Please sign in to comment.