Skip to content

Commit

Permalink
decimal: implement Sqrt and SqrtExact methods
Browse files Browse the repository at this point in the history
  • Loading branch information
eapenkin authored Jun 22, 2024
1 parent b05531a commit 0beed27
Show file tree
Hide file tree
Showing 7 changed files with 374 additions and 7 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ jobs:
- name: Run fuzzing for fused multiply-addition
run: go test -fuzztime 60s -fuzz ^FuzzDecimal_FMA$

- name: Run fuzzing for square root
run: go test -fuzztime 20s -fuzz ^FuzzDecimal_Sqrt$

- name: Run fuzzing for division
run: go test -fuzztime 20s -fuzz ^FuzzDecimal_Quo$

Expand Down
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## [0.1.28] - 2024-06-22

### Added

- Implemented `Decimal.Sqrt`, `Decimal.SqrtExact`.

## [0.1.27] - 2024-05-19

### Changed
Expand Down
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ func main() {

fmt.Println(d.Mul(e)) // 8 * 12.5
fmt.Println(d.FMA(e, f)) // 8 * 12.5 + 2.567
fmt.Println(d.Pow(2)) // 8 ^ 2
fmt.Println(d.Pow(2)) //
fmt.Println(d.Sqrt()) // √8

fmt.Println(d.Quo(e)) // 8 ÷ 12.5
fmt.Println(d.QuoRem(e)) // 8 div 12.5, 8 mod 12.5
Expand Down Expand Up @@ -127,10 +128,11 @@ cpu: AMD Ryzen 7 3700C with Radeon Vega Mobile Gfx
| Add | 5 + 6 | 16.06n | 74.88n | 140.90n | +366.22% | +777.33% |
| Mul | 2 * 3 | 16.93n | 62.20n | 146.00n | +267.40% | +762.37% |
| QuoExact | 2 ÷ 4 | 59.52n | 176.95n | 657.40n | +197.30% | +1004.50% |
| QuoInfinite | 2 ÷ 3 | 391.60n | 976.8n | 2962.50n | +149.39% | +656.42% |
| QuoInfinite | 2 ÷ 3 | 391.60n | 976.80n | 2962.50n | +149.39% | +656.42% |
| Pow | 1.1^60 | 950.90n | 3302.50n | 4599.50n | +247.32% | +383.73% |
| Pow | 1.01^600 | 3.45µ | 10.67µ | 18.67µ | +209.04% | +440.89% |
| Pow | 1.001^6000 | 5.94µ | 20.50µ | 722.22µ | +244.88% | +12052.44% |
| Sqrt | √2 | 3.49µ | 4.68µ | 498.36µ | +34.07% | +14187.84% |
| Parse | 1 | 16.52n | 76.30n | 136.55n | +362.00% | +726.82% |
| Parse | 123.456 | 47.37n | 176.90n | 242.60n | +273.44% | +412.14% |
| Parse | 123456789.1234567890 | 85.49n | 224.15n | 497.95n | +162.19% | +482.47% |
Expand Down
75 changes: 73 additions & 2 deletions decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,7 @@ func (d Decimal) Format(state fmt.State, verb rune) {
}

// Writing result
//nolint:errcheck
switch verb {
case 'q', 'Q', 's', 'S', 'v', 'V', 'f', 'F', 'k', 'K':
state.Write(buf)
Expand Down Expand Up @@ -1270,8 +1271,6 @@ func (d Decimal) Pow(power int) (Decimal, error) {
// of digits after the decimal point that should be considered significant.
// If any of the significant digits are lost during rounding, the method will
// return an overflow error.
// This method is useful for financial calculations where the scale should be
// equal to or greater than the currency's scale.
func (d Decimal) PowExact(power, scale int) (Decimal, error) {
if scale < MinScale || scale > MaxScale {
return Decimal{}, fmt.Errorf("computing [%v^%v]: %w", d, power, errScaleRange)
Expand Down Expand Up @@ -1429,6 +1428,78 @@ func (d Decimal) powBint(power, minScale int) (Decimal, error) {
return newFromBint(eneg, ecoef, escale, minScale)
}

// Sqrt computes the square root of a decimal.
//
// Sqrt returns an error if the decimal is negative.
func (d Decimal) Sqrt() (Decimal, error) {
return d.SqrtExact(0)
}

// SqrtExact is similar to [Decimal.Sqrt], but it allows you to specify the number of digits
// after the decimal point that should be considered significant.
// If any of the significant digits are lost during rounding, the method will return an error.
func (d Decimal) SqrtExact(scale int) (Decimal, error) {
// Special case: negative
if d.IsNeg() {
return Decimal{}, fmt.Errorf("computing sqrt(%v): %w", d, errInvalidOperation)
}

// Special case: zero
if d.IsZero() {
scale = max(scale, d.Scale()/2)
return newSafe(false, 0, scale)
}

// General case
e, err := d.sqrtBint(scale)
if err != nil {
return Decimal{}, fmt.Errorf("computing sqrt(%v): %w", d, err)
}

// Preferred scale
scale = max(scale, d.Scale()/2)
e = e.Trim(scale)

return e, nil
}

// sqrtBint computes the square root of a decimal using *big.Int arithmetic.
func (d Decimal) sqrtBint(minScale int) (Decimal, error) {
dcoef := getBint()
defer putBint(dcoef)
dcoef.setFint(d.coef)

ecoef := getBint()
defer putBint(ecoef)

fcoef := getBint()
defer putBint(fcoef)

two := getBint()
defer putBint(two)
two.setFint(2)

// Babylonian method
dcoef.lsh(dcoef, 2*MaxScale-d.Scale())
ecoef.quo(dcoef, two)
fcoef.setFint(1)
fcoef.lsh(fcoef, 2*MaxScale)
fcoef.add(fcoef, ecoef)
dcoef.lsh(dcoef, 2*MaxScale)

for {
if ecoef.cmp(fcoef) == 0 {
break
}
fcoef.setBint(ecoef)
ecoef.quo(dcoef, ecoef)
ecoef.add(ecoef, fcoef)
ecoef.quo(ecoef, two)
}

return newFromBint(false, ecoef, 2*MaxScale, minScale)
}

// Add returns the (possibly rounded) sum of decimals d and e.
//
// Add returns an error if the integer part of the result has more than [MaxPrec] digits.
Expand Down
139 changes: 139 additions & 0 deletions decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2319,6 +2319,80 @@ func TestDecimal_Pow(t *testing.T) {
})
}

func TestDecimal_Sqrt(t *testing.T) {
t.Run("success", func(t *testing.T) {
tests := []struct {
d, want string
}{
// Zeros
{"0", "0"},
{"0.0", "0"},
{"0.00", "0.0"},
{"0.000", "0.0"},
{"0.0000", "0.00"},

// Numbers
{"0", "0"},
{"1", "1"},
{"2", "1.414213562373095049"},
{"3", "1.732050807568877294"},
{"4", "2"},
{"5", "2.236067977499789696"},
{"6", "2.449489742783178098"},
{"7", "2.645751311064590591"},
{"8", "2.828427124746190098"},
{"9", "3"},
{"10", "3.162277660168379332"},
{"11", "3.316624790355399849"},
{"12", "3.464101615137754587"},
{"13", "3.605551275463989293"},
{"14", "3.741657386773941386"},
{"15", "3.872983346207416885"},
{"16", "4"},
{"17", "4.12310562561766055"},
{"18", "4.242640687119285146"},
{"19", "4.358898943540673552"},
{"20", "4.472135954999579393"},
{"21", "4.582575694955840007"},
{"22", "4.690415759823429555"},
{"23", "4.795831523312719542"},
{"24", "4.898979485566356196"},
{"25", "5"},

// Edge cases
{"0.0000000000000000001", "0.000000000316227766"},
{"9999999999999999999", "3162277660.168379332"},
}
for _, tt := range tests {
d := MustParse(tt.d)
got, err := d.Sqrt()
if err != nil {
t.Errorf("%q.Sqrt() failed: %v", d, err)
continue
}
want := MustParse(tt.want)
if got != want {
t.Errorf("%q.Sqrt() = %q, want %q", d, got, want)
}
}
})

t.Run("error", func(t *testing.T) {
tests := map[string]string{
"negative": "-1",
}
for name, d := range tests {
t.Run(name, func(t *testing.T) {
d := MustParse(d)
_, err := d.Sqrt()
if err == nil {
t.Errorf("%q.Sqrt() did not fail", d)
}
})
}
})
}

func TestDecimal_Abs(t *testing.T) {
tests := []struct {
d, want string
Expand Down Expand Up @@ -3438,6 +3512,71 @@ func FuzzDecimal_Cmp(f *testing.F) {
)
}

func FuzzDecimal_Sqrt(f *testing.F) {
for _, d := range corpus {
f.Add(d.neg, d.scale, d.coef)
}

f.Fuzz(
func(t *testing.T, neg bool, scale int, coef uint64) {
if neg {
t.Skip()
return
}
if scale < 0 || MaxScale < scale {
t.Skip()
return
}
want, err := newSafe(neg, fint(coef), scale)
if err != nil {
t.Skip()
return
}
d, err := want.Sqrt()
if err != nil {
t.Errorf("%q.Sqrt() failed: %v", want, err)
return
}
got, err := d.Pow(2)
if err != nil {
if errors.Is(err, errDecimalOverflow) {
t.Skip() // Decimal overflow is an expected error here
} else {
t.Errorf("%q.Pow(2) failed: %v", d, err)
}
return
}
if cmp, err := cmp3ULP(got, want); err != nil {
t.Errorf("cmpULP(%q, %q) failed: %v", got, want, err)
} else if cmp != 0 {
t.Errorf("%q.Sqrt().Pow(2) = %q, want %q", want, got, want)
return
}
},
)
}

// cmp3ULP compares decimals and returns 0 if they are within 3 ULPs.
func cmp3ULP(d, e Decimal) (int, error) {
three, err := New(3, 0)
if err != nil {
return 0, err
}
dist, err := d.SubAbs(e)
if err != nil {
return 0, err
}
ulp := d.ULP().Min(e.ULP())
tlr, err := ulp.Mul(three)
if err != nil {
return 0, err
}
if dist.Cmp(tlr) <= 0 {
return 0, nil
}
return d.Cmp(e), nil
}

func FuzzDecimal_CmpSub(f *testing.F) {
for _, d := range corpus {
for _, e := range corpus {
Expand Down
9 changes: 6 additions & 3 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ or errors.
# Operations
Each arithmetic operation occurs in two steps:
Each arithmetic operation, except for [Decimal.Sqrt] and [Decimal.SqrtExact],
occurs in two steps:
1. The operation is initially performed using uint64 arithmetic.
If no overflow occurs, the exact result is immediately returned.
Expand All @@ -72,11 +73,11 @@ will compute an exact result during step 1.
The following rules determine the significance of digits during step 2:
- [Decimal.Add], [Decimal.Sub], [Decimal.Mul], [Decimal.FMA], [Decimal.Pow],
[Decimal.Quo], [Decimal.QuoRem], [Decimal.Inv]:
[Decimal.Quo], [Decimal.QuoRem], [Decimal.Inv], [Decimal.Sqrt]:
All digits in the integer part are significant, while digits in the
fractional part are considered insignificant.
- [Decimal.AddExact], [Decimal.SubExact], [Decimal.MulExact], [Decimal.FMAExact],
[Decimal.PowExact], [Decimal.QuoExact]:
[Decimal.PowExact], [Decimal.QuoExact], [Decimal.SqrtExact]:
All digits in the integer part are significant. The significance of digits
in the fractional part is determined by the scale argument, which is typically
equal to the scale of the currency.
Expand Down Expand Up @@ -141,6 +142,8 @@ Errors are returned in the following cases:
- Invalid Operation:
[Decimal.Pow] and [Decimal.PowExact] return an error if 0 is raised to
a negative power.
[Decimal.Sqrt] and [Decimal.SqrtExact] return an error if the square root
of a negative decimal is requested.
- Overflow:
Unlike standard integers, there is no "wrap around" for decimals at certain sizes.
Expand Down
Loading

0 comments on commit 0beed27

Please sign in to comment.