From 0beed27df3c70e46ee69cbe135f12bc6a02a990f Mon Sep 17 00:00:00 2001 From: eapenkin Date: Sat, 22 Jun 2024 19:00:53 +0400 Subject: [PATCH] decimal: implement Sqrt and SqrtExact methods --- .github/workflows/go.yml | 3 + CHANGELOG.md | 6 ++ README.md | 6 +- decimal.go | 75 +++++++++++++++++++- decimal_test.go | 139 +++++++++++++++++++++++++++++++++++++ doc.go | 9 ++- doc_test.go | 143 +++++++++++++++++++++++++++++++++++++++ 7 files changed, 374 insertions(+), 7 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 3990652..0b95668 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -80,6 +80,9 @@ jobs: - name: Run fuzzing for fused multiply-addition run: go test -fuzztime 60s -fuzz ^FuzzDecimal_FMA$ + - name: Run fuzzing for square root + run: go test -fuzztime 20s -fuzz ^FuzzDecimal_Sqrt$ + - name: Run fuzzing for division run: go test -fuzztime 20s -fuzz ^FuzzDecimal_Quo$ diff --git a/CHANGELOG.md b/CHANGELOG.md index a2a0c44..3bd4b16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [0.1.28] - 2024-06-22 + +### Added + +- Implemented `Decimal.Sqrt`, `Decimal.SqrtExact`. + ## [0.1.27] - 2024-05-19 ### Changed diff --git a/README.md b/README.md index 5524d2a..f1fe8a7 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,8 @@ func main() { fmt.Println(d.Mul(e)) // 8 * 12.5 fmt.Println(d.FMA(e, f)) // 8 * 12.5 + 2.567 - fmt.Println(d.Pow(2)) // 8 ^ 2 + fmt.Println(d.Pow(2)) // 8² + fmt.Println(d.Sqrt()) // √8 fmt.Println(d.Quo(e)) // 8 ÷ 12.5 fmt.Println(d.QuoRem(e)) // 8 div 12.5, 8 mod 12.5 @@ -127,10 +128,11 @@ cpu: AMD Ryzen 7 3700C with Radeon Vega Mobile Gfx | Add | 5 + 6 | 16.06n | 74.88n | 140.90n | +366.22% | +777.33% | | Mul | 2 * 3 | 16.93n | 62.20n | 146.00n | +267.40% | +762.37% | | QuoExact | 2 ÷ 4 | 59.52n | 176.95n | 657.40n | +197.30% | +1004.50% | -| QuoInfinite | 2 ÷ 3 | 391.60n | 976.8n | 2962.50n | +149.39% | +656.42% | +| QuoInfinite | 2 ÷ 3 | 391.60n | 976.80n | 2962.50n | +149.39% | +656.42% | | Pow | 1.1^60 | 950.90n | 3302.50n | 4599.50n | +247.32% | +383.73% | | Pow | 1.01^600 | 3.45µ | 10.67µ | 18.67µ | +209.04% | +440.89% | | Pow | 1.001^6000 | 5.94µ | 20.50µ | 722.22µ | +244.88% | +12052.44% | +| Sqrt | √2 | 3.49µ | 4.68µ | 498.36µ | +34.07% | +14187.84% | | Parse | 1 | 16.52n | 76.30n | 136.55n | +362.00% | +726.82% | | Parse | 123.456 | 47.37n | 176.90n | 242.60n | +273.44% | +412.14% | | Parse | 123456789.1234567890 | 85.49n | 224.15n | 497.95n | +162.19% | +482.47% | diff --git a/decimal.go b/decimal.go index e187c17..30a67d4 100644 --- a/decimal.go +++ b/decimal.go @@ -931,6 +931,7 @@ func (d Decimal) Format(state fmt.State, verb rune) { } // Writing result + //nolint:errcheck switch verb { case 'q', 'Q', 's', 'S', 'v', 'V', 'f', 'F', 'k', 'K': state.Write(buf) @@ -1270,8 +1271,6 @@ func (d Decimal) Pow(power int) (Decimal, error) { // of digits after the decimal point that should be considered significant. // If any of the significant digits are lost during rounding, the method will // return an overflow error. -// This method is useful for financial calculations where the scale should be -// equal to or greater than the currency's scale. func (d Decimal) PowExact(power, scale int) (Decimal, error) { if scale < MinScale || scale > MaxScale { return Decimal{}, fmt.Errorf("computing [%v^%v]: %w", d, power, errScaleRange) @@ -1429,6 +1428,78 @@ func (d Decimal) powBint(power, minScale int) (Decimal, error) { return newFromBint(eneg, ecoef, escale, minScale) } +// Sqrt computes the square root of a decimal. +// +// Sqrt returns an error if the decimal is negative. +func (d Decimal) Sqrt() (Decimal, error) { + return d.SqrtExact(0) +} + +// SqrtExact is similar to [Decimal.Sqrt], but it allows you to specify the number of digits +// after the decimal point that should be considered significant. +// If any of the significant digits are lost during rounding, the method will return an error. +func (d Decimal) SqrtExact(scale int) (Decimal, error) { + // Special case: negative + if d.IsNeg() { + return Decimal{}, fmt.Errorf("computing sqrt(%v): %w", d, errInvalidOperation) + } + + // Special case: zero + if d.IsZero() { + scale = max(scale, d.Scale()/2) + return newSafe(false, 0, scale) + } + + // General case + e, err := d.sqrtBint(scale) + if err != nil { + return Decimal{}, fmt.Errorf("computing sqrt(%v): %w", d, err) + } + + // Preferred scale + scale = max(scale, d.Scale()/2) + e = e.Trim(scale) + + return e, nil +} + +// sqrtBint computes the square root of a decimal using *big.Int arithmetic. +func (d Decimal) sqrtBint(minScale int) (Decimal, error) { + dcoef := getBint() + defer putBint(dcoef) + dcoef.setFint(d.coef) + + ecoef := getBint() + defer putBint(ecoef) + + fcoef := getBint() + defer putBint(fcoef) + + two := getBint() + defer putBint(two) + two.setFint(2) + + // Babylonian method + dcoef.lsh(dcoef, 2*MaxScale-d.Scale()) + ecoef.quo(dcoef, two) + fcoef.setFint(1) + fcoef.lsh(fcoef, 2*MaxScale) + fcoef.add(fcoef, ecoef) + dcoef.lsh(dcoef, 2*MaxScale) + + for { + if ecoef.cmp(fcoef) == 0 { + break + } + fcoef.setBint(ecoef) + ecoef.quo(dcoef, ecoef) + ecoef.add(ecoef, fcoef) + ecoef.quo(ecoef, two) + } + + return newFromBint(false, ecoef, 2*MaxScale, minScale) +} + // Add returns the (possibly rounded) sum of decimals d and e. // // Add returns an error if the integer part of the result has more than [MaxPrec] digits. diff --git a/decimal_test.go b/decimal_test.go index e7ee9c8..29113be 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -2319,6 +2319,80 @@ func TestDecimal_Pow(t *testing.T) { }) } +func TestDecimal_Sqrt(t *testing.T) { + t.Run("success", func(t *testing.T) { + tests := []struct { + d, want string + }{ + // Zeros + {"0", "0"}, + {"0.0", "0"}, + {"0.00", "0.0"}, + {"0.000", "0.0"}, + {"0.0000", "0.00"}, + + // Numbers + {"0", "0"}, + {"1", "1"}, + {"2", "1.414213562373095049"}, + {"3", "1.732050807568877294"}, + {"4", "2"}, + {"5", "2.236067977499789696"}, + {"6", "2.449489742783178098"}, + {"7", "2.645751311064590591"}, + {"8", "2.828427124746190098"}, + {"9", "3"}, + {"10", "3.162277660168379332"}, + {"11", "3.316624790355399849"}, + {"12", "3.464101615137754587"}, + {"13", "3.605551275463989293"}, + {"14", "3.741657386773941386"}, + {"15", "3.872983346207416885"}, + {"16", "4"}, + {"17", "4.12310562561766055"}, + {"18", "4.242640687119285146"}, + {"19", "4.358898943540673552"}, + {"20", "4.472135954999579393"}, + {"21", "4.582575694955840007"}, + {"22", "4.690415759823429555"}, + {"23", "4.795831523312719542"}, + {"24", "4.898979485566356196"}, + {"25", "5"}, + + // Edge cases + {"0.0000000000000000001", "0.000000000316227766"}, + {"9999999999999999999", "3162277660.168379332"}, + } + for _, tt := range tests { + d := MustParse(tt.d) + got, err := d.Sqrt() + if err != nil { + t.Errorf("%q.Sqrt() failed: %v", d, err) + continue + } + want := MustParse(tt.want) + if got != want { + t.Errorf("%q.Sqrt() = %q, want %q", d, got, want) + } + } + }) + + t.Run("error", func(t *testing.T) { + tests := map[string]string{ + "negative": "-1", + } + for name, d := range tests { + t.Run(name, func(t *testing.T) { + d := MustParse(d) + _, err := d.Sqrt() + if err == nil { + t.Errorf("%q.Sqrt() did not fail", d) + } + }) + } + }) +} + func TestDecimal_Abs(t *testing.T) { tests := []struct { d, want string @@ -3438,6 +3512,71 @@ func FuzzDecimal_Cmp(f *testing.F) { ) } +func FuzzDecimal_Sqrt(f *testing.F) { + for _, d := range corpus { + f.Add(d.neg, d.scale, d.coef) + } + + f.Fuzz( + func(t *testing.T, neg bool, scale int, coef uint64) { + if neg { + t.Skip() + return + } + if scale < 0 || MaxScale < scale { + t.Skip() + return + } + want, err := newSafe(neg, fint(coef), scale) + if err != nil { + t.Skip() + return + } + d, err := want.Sqrt() + if err != nil { + t.Errorf("%q.Sqrt() failed: %v", want, err) + return + } + got, err := d.Pow(2) + if err != nil { + if errors.Is(err, errDecimalOverflow) { + t.Skip() // Decimal overflow is an expected error here + } else { + t.Errorf("%q.Pow(2) failed: %v", d, err) + } + return + } + if cmp, err := cmp3ULP(got, want); err != nil { + t.Errorf("cmpULP(%q, %q) failed: %v", got, want, err) + } else if cmp != 0 { + t.Errorf("%q.Sqrt().Pow(2) = %q, want %q", want, got, want) + return + } + }, + ) +} + +// cmp3ULP compares decimals and returns 0 if they are within 3 ULPs. +func cmp3ULP(d, e Decimal) (int, error) { + three, err := New(3, 0) + if err != nil { + return 0, err + } + dist, err := d.SubAbs(e) + if err != nil { + return 0, err + } + ulp := d.ULP().Min(e.ULP()) + tlr, err := ulp.Mul(three) + if err != nil { + return 0, err + } + if dist.Cmp(tlr) <= 0 { + return 0, nil + } + return d.Cmp(e), nil +} + func FuzzDecimal_CmpSub(f *testing.F) { for _, d := range corpus { for _, e := range corpus { diff --git a/doc.go b/doc.go index 229aa04..59ea973 100644 --- a/doc.go +++ b/doc.go @@ -54,7 +54,8 @@ or errors. # Operations -Each arithmetic operation occurs in two steps: +Each arithmetic operation, except for [Decimal.Sqrt] and [Decimal.SqrtExact], +occurs in two steps: 1. The operation is initially performed using uint64 arithmetic. If no overflow occurs, the exact result is immediately returned. @@ -72,11 +73,11 @@ will compute an exact result during step 1. The following rules determine the significance of digits during step 2: - [Decimal.Add], [Decimal.Sub], [Decimal.Mul], [Decimal.FMA], [Decimal.Pow], - [Decimal.Quo], [Decimal.QuoRem], [Decimal.Inv]: + [Decimal.Quo], [Decimal.QuoRem], [Decimal.Inv], [Decimal.Sqrt]: All digits in the integer part are significant, while digits in the fractional part are considered insignificant. - [Decimal.AddExact], [Decimal.SubExact], [Decimal.MulExact], [Decimal.FMAExact], - [Decimal.PowExact], [Decimal.QuoExact]: + [Decimal.PowExact], [Decimal.QuoExact], [Decimal.SqrtExact]: All digits in the integer part are significant. The significance of digits in the fractional part is determined by the scale argument, which is typically equal to the scale of the currency. @@ -141,6 +142,8 @@ Errors are returned in the following cases: - Invalid Operation: [Decimal.Pow] and [Decimal.PowExact] return an error if 0 is raised to a negative power. + [Decimal.Sqrt] and [Decimal.SqrtExact] return an error if the square root + of a negative decimal is requested. - Overflow: Unlike standard integers, there is no "wrap around" for decimals at certain sizes. diff --git a/doc_test.go b/doc_test.go index 3984ade..e975b95 100644 --- a/doc_test.go +++ b/doc_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "encoding/xml" "fmt" + "slices" "strings" "github.com/govalues/decimal" @@ -538,6 +539,37 @@ func ExampleDecimal_PowExact() { // 8.0000 } +func ExampleDecimal_Sqrt() { + d := decimal.MustParse("1") + e := decimal.MustParse("2") + f := decimal.MustParse("3") + g := decimal.MustParse("4") + fmt.Println(d.Sqrt()) + fmt.Println(e.Sqrt()) + fmt.Println(f.Sqrt()) + fmt.Println(g.Sqrt()) + // Output: + // 1 + // 1.414213562373095049 + // 1.732050807568877294 + // 2 +} + +func ExampleDecimal_SqrtExact() { + d := decimal.MustParse("4") + fmt.Println(d.SqrtExact(0)) + fmt.Println(d.SqrtExact(1)) + fmt.Println(d.SqrtExact(2)) + fmt.Println(d.SqrtExact(3)) + fmt.Println(d.SqrtExact(4)) + // Output: + // 2 + // 2.0 + // 2.00 + // 2.000 + // 2.0000 +} + func ExampleDecimal_Add() { d := decimal.MustParse("5.67") e := decimal.MustParse("8") @@ -639,6 +671,28 @@ func ExampleDecimal_Cmp() { // 1 } +func ExampleDecimal_Cmp_slices() { + s := []decimal.Decimal{ + decimal.MustParse("-5.67"), + decimal.MustParse("23"), + decimal.MustParse("0"), + } + fmt.Println(slices.CompareFunc(s, s, decimal.Decimal.Cmp)) + fmt.Println(slices.MaxFunc(s, decimal.Decimal.Cmp)) + fmt.Println(slices.MinFunc(s, decimal.Decimal.Cmp)) + fmt.Println(s, slices.IsSortedFunc(s, decimal.Decimal.Cmp)) + slices.SortFunc(s, decimal.Decimal.Cmp) + fmt.Println(s, slices.IsSortedFunc(s, decimal.Decimal.Cmp)) + fmt.Println(slices.BinarySearchFunc(s, decimal.MustParse("1"), decimal.Decimal.Cmp)) + // Output: + // 0 + // 23 + // -5.67 + // [-5.67 23 0] false + // [-5.67 0 23] true + // 2 false +} + func ExampleDecimal_CmpAbs() { d := decimal.MustParse("-23") e := decimal.MustParse("5.67") @@ -651,6 +705,28 @@ func ExampleDecimal_CmpAbs() { // -1 } +func ExampleDecimal_CmpAbs_slices() { + s := []decimal.Decimal{ + decimal.MustParse("-5.67"), + decimal.MustParse("23"), + decimal.MustParse("0"), + } + fmt.Println(slices.CompareFunc(s, s, decimal.Decimal.CmpAbs)) + fmt.Println(slices.MaxFunc(s, decimal.Decimal.CmpAbs)) + fmt.Println(slices.MinFunc(s, decimal.Decimal.CmpAbs)) + fmt.Println(s, slices.IsSortedFunc(s, decimal.Decimal.CmpAbs)) + slices.SortFunc(s, decimal.Decimal.CmpAbs) + fmt.Println(s, slices.IsSortedFunc(s, decimal.Decimal.CmpAbs)) + fmt.Println(slices.BinarySearchFunc(s, decimal.MustParse("1"), decimal.Decimal.CmpAbs)) + // Output: + // 0 + // 23 + // 0 + // [-5.67 23 0] false + // [0 -5.67 23] true + // 1 false +} + func ExampleDecimal_CmpTotal() { d := decimal.MustParse("2.0") e := decimal.MustParse("2.00") @@ -663,6 +739,28 @@ func ExampleDecimal_CmpTotal() { // -1 } +func ExampleDecimal_CmpTotal_slices() { + s := []decimal.Decimal{ + decimal.MustParse("-5.67"), + decimal.MustParse("23"), + decimal.MustParse("0"), + } + fmt.Println(slices.CompareFunc(s, s, decimal.Decimal.CmpTotal)) + fmt.Println(slices.MaxFunc(s, decimal.Decimal.CmpTotal)) + fmt.Println(slices.MinFunc(s, decimal.Decimal.CmpTotal)) + fmt.Println(s, slices.IsSortedFunc(s, decimal.Decimal.CmpTotal)) + slices.SortFunc(s, decimal.Decimal.CmpTotal) + fmt.Println(s, slices.IsSortedFunc(s, decimal.Decimal.CmpTotal)) + fmt.Println(slices.BinarySearchFunc(s, decimal.MustParse("10"), decimal.Decimal.CmpTotal)) + // Output: + // 0 + // 23 + // -5.67 + // [-5.67 23 0] false + // [-5.67 0 23] true + // 2 false +} + func ExampleDecimal_Max() { d := decimal.MustParse("23") e := decimal.MustParse("-5.67") @@ -890,6 +988,21 @@ func ExampleDecimal_IsNeg() { // false } +func ExampleDecimal_IsNeg_slices() { + s := []decimal.Decimal{ + decimal.MustParse("-5.67"), + decimal.MustParse("23"), + decimal.MustParse("0"), + } + fmt.Println(slices.ContainsFunc(s, decimal.Decimal.IsNeg)) + fmt.Println(slices.IndexFunc(s, decimal.Decimal.IsNeg)) + fmt.Println(slices.DeleteFunc(s, decimal.Decimal.IsNeg)) + // Output: + // true + // 0 + // [23 0] +} + func ExampleDecimal_IsPos() { d := decimal.MustParse("-5.67") e := decimal.MustParse("23") @@ -903,6 +1016,21 @@ func ExampleDecimal_IsPos() { // false } +func ExampleDecimal_IsPos_slices() { + s := []decimal.Decimal{ + decimal.MustParse("-5.67"), + decimal.MustParse("23"), + decimal.MustParse("0"), + } + fmt.Println(slices.ContainsFunc(s, decimal.Decimal.IsPos)) + fmt.Println(slices.IndexFunc(s, decimal.Decimal.IsPos)) + fmt.Println(slices.DeleteFunc(s, decimal.Decimal.IsPos)) + // Output: + // true + // 1 + // [-5.67 0] +} + func ExampleDecimal_IsZero() { d := decimal.MustParse("-5.67") e := decimal.MustParse("23") @@ -916,6 +1044,21 @@ func ExampleDecimal_IsZero() { // true } +func ExampleDecimal_IsZero_slices() { + s := []decimal.Decimal{ + decimal.MustParse("-5.67"), + decimal.MustParse("23"), + decimal.MustParse("0"), + } + fmt.Println(slices.ContainsFunc(s, decimal.Decimal.IsZero)) + fmt.Println(slices.IndexFunc(s, decimal.Decimal.IsZero)) + fmt.Println(slices.DeleteFunc(s, decimal.Decimal.IsZero)) + // Output: + // true + // 2 + // [-5.67 23] +} + func ExampleDecimal_IsInt() { d := decimal.MustParse("1.00") e := decimal.MustParse("1.01")