diff --git a/price/main.go b/price/main.go index 06038bd44c..dbe5f74a38 100644 --- a/price/main.go +++ b/price/main.go @@ -28,6 +28,8 @@ var ( ErrDivisionByZero = errors.New("division by 0") // ErrOverflow is returned when a price operation would result in an integer overflow ErrOverflow = errors.New("overflow") + // ErrNoNegatives is returned when a price operation is given a negative number + ErrNoNegatives = errors.New("negative numbers are not allowed") ) // Parse calculates and returns the best rational approximation of the given @@ -181,6 +183,9 @@ func MulFractionRoundDown(x int64, n int64, d int64) (int64, error) { if d == 0 { return 0, ErrDivisionByZero } + if x < 0 || n < 0 || d < 0 { + return 0, ErrNoNegatives + } hi, lo := bits.Mul64(uint64(x), uint64(n)) @@ -202,6 +207,9 @@ func mulFractionRoundUp(x int64, n int64, d int64) (int64, error) { if d == 0 { return 0, ErrDivisionByZero } + if x < 0 || n < 0 || d < 0 { + return 0, ErrNoNegatives + } hi, lo := bits.Mul64(uint64(x), uint64(n)) lo, carry := bits.Add64(lo, uint64(d-1), 0) diff --git a/xdr/price.go b/xdr/price.go index eed12c7339..7fdb292732 100644 --- a/xdr/price.go +++ b/xdr/price.go @@ -1,26 +1,46 @@ package xdr import ( + "fmt" "math/big" ) -// String returns a string representation of `p` +// String returns a string representation of `p` if `p` is valid, +// and string indicating the price's invalidity otherwise. +// Satisfies the fmt.Stringer interface. func (p Price) String() string { + if err := p.Validate(); err != nil { + return fmt.Sprintf("", p.N, p.D, err) + } return big.NewRat(int64(p.N), int64(p.D)).FloatString(7) } -// Equal returns whether the price's value is the same, +// TryEqual returns whether the price's value is the same, // taking into account denormalized representation // (e.g. Price{1, 2}.EqualValue(Price{2,4}) == true ) -func (p Price) Equal(q Price) bool { - // See the Cheaper() method for the reasoning behind this: - return uint64(p.N)*uint64(q.D) == uint64(q.N)*uint64(p.D) +// Returns an error if either price is invalid. +func (p Price) TryEqual(q Price) (bool, error) { + if err := p.Validate(); err != nil { + return false, fmt.Errorf("invalid price p: %w", err) + } + if err := q.Validate(); err != nil { + return false, fmt.Errorf("invalid price q: %w", err) + } + // See the TryCheaper() method for the reasoning behind this: + return uint64(p.N)*uint64(q.D) == uint64(q.N)*uint64(p.D), nil } -// Cheaper indicates if the Price's value is lower, -// taking into account denormalized representation +// TryCheaper indicates if the Price's value is lower, +// taking into account denormalized representation. // (e.g. Price{1, 2}.Cheaper(Price{2,4}) == false ) -func (p Price) Cheaper(q Price) bool { +// Returns an error if either price is invalid +func (p Price) TryCheaper(q Price) (bool, error) { + if err := p.Validate(); err != nil { + return false, fmt.Errorf("invalid price p: %w", err) + } + if err := q.Validate(); err != nil { + return false, fmt.Errorf("invalid price q: %w", err) + } // To avoid float precision issues when naively comparing Price.N/Price.D, // we use the cross product instead: // @@ -31,10 +51,69 @@ func (p Price) Cheaper(q Price) bool { // (p.N / p.D) * (p.D * q.D) < (q.N / q.D) * (p.D * q.D) // <==> // p.N * q.D < q.N * p.D + return uint64(p.N)*uint64(q.D) < uint64(q.N)*uint64(p.D), nil +} + +// TryNormalize sets the price to its rational canonical form. +// Returns an error if the price is invalid +func (p *Price) TryNormalize() error { + if err := p.Validate(); err != nil { + return fmt.Errorf("invalid price: %w", err) + } + r := big.NewRat(int64(p.N), int64(p.D)) + p.N = Int32(r.Num().Int64()) + p.D = Int32(r.Denom().Int64()) + return nil +} + +// TryInvert inverts Price. +// Returns an error if the price is invalid +func (p *Price) TryInvert() error { + if err := p.Validate(); err != nil { + return fmt.Errorf("invalid price: %w", err) + } + p.N, p.D = p.D, p.N + return nil +} + +// Validate checks if the price is valid and returns an error if not. +func (p Price) Validate() error { + if p.N == 0 { + return fmt.Errorf("price cannot be 0: %d/%d", p.N, p.D) + } + if p.D == 0 { + return fmt.Errorf("price denominator cannot be 0: %d/%d", p.N, p.D) + } + if p.N < 0 || p.D < 0 { + return fmt.Errorf("price cannot be negative: %d/%d", p.N, p.D) + } + return nil +} + +// Equal returns whether the price's value is the same, +// taking into account denormalized representation +// (e.g. Price{1, 2}.EqualValue(Price{2,4}) == true ). +// It does not validate the prices and may produce incorrect results for invalid inputs. +// +// Deprecated: Use TryEqual instead, which returns an error for invalid prices. +func (p Price) Equal(q Price) bool { + return uint64(p.N)*uint64(q.D) == uint64(q.N)*uint64(p.D) +} + +// Cheaper indicates if the Price's value is lower, +// taking into account denormalized representation +// (e.g. Price{1, 2}.Cheaper(Price{2,4}) == false ). +// It does not validate the prices and may produce incorrect results for invalid inputs. +// +// Deprecated: Use TryCheaper instead, which returns an error for invalid prices. +func (p Price) Cheaper(q Price) bool { return uint64(p.N)*uint64(q.D) < uint64(q.N)*uint64(p.D) } -// Normalize sets Price to its rational canonical form +// Normalize sets Price to its rational canonical form. +// It panics if the price denominator is zero. +// +// Deprecated: Use TryNormalize instead, which returns an error for invalid prices. func (p *Price) Normalize() { r := big.NewRat(int64(p.N), int64(p.D)) p.N = Int32(r.Num().Int64()) @@ -42,6 +121,9 @@ func (p *Price) Normalize() { } // Invert inverts Price. +// It may set a Price with zero denominator if the original price's numerator is zero. +// +// Deprecated: Use TryInvert instead, which returns an error for invalid prices. func (p *Price) Invert() { p.N, p.D = p.D, p.N } diff --git a/xdr/price_test.go b/xdr/price_test.go index 875925dd92..34e2011ab5 100644 --- a/xdr/price_test.go +++ b/xdr/price_test.go @@ -8,18 +8,52 @@ import ( "github.com/stellar/go-stellar-sdk/xdr" ) +func makeAssertBoolNoError(t *testing.T) (func(bool, error), func(bool, error)) { + t.Helper() + assertTrueNoError := func(res bool, err error) { + t.Helper() + assert.NoError(t, err) + assert.True(t, res) + } + assertFalseNoError := func(res bool, err error) { + t.Helper() + assert.NoError(t, err) + assert.False(t, res) + } + return assertTrueNoError, assertFalseNoError +} + func TestPriceInvert(t *testing.T) { p := xdr.Price{N: 1, D: 2} + assert.NoError(t, p.TryInvert()) + assert.Equal(t, xdr.Price{N: 2, D: 1}, p) + + // Using deprecated Price.Invert() method + p = xdr.Price{N: 1, D: 2} p.Invert() assert.Equal(t, xdr.Price{N: 2, D: 1}, p) } func TestPriceEqual(t *testing.T) { + assertTrueNoError, assertFalseNoError := makeAssertBoolNoError(t) + // canonical + assertTrueNoError(xdr.Price{N: 1, D: 2}.TryEqual(xdr.Price{N: 1, D: 2})) + assertFalseNoError(xdr.Price{N: 1, D: 2}.TryEqual(xdr.Price{N: 2, D: 3})) + + // not canonical + assertTrueNoError(xdr.Price{N: 1, D: 2}.TryEqual(xdr.Price{N: 5, D: 10})) + assertTrueNoError(xdr.Price{N: 5, D: 10}.TryEqual(xdr.Price{N: 1, D: 2})) + assertTrueNoError(xdr.Price{N: 5, D: 10}.TryEqual(xdr.Price{N: 50, D: 100})) + assertFalseNoError(xdr.Price{N: 1, D: 3}.TryEqual(xdr.Price{N: 5, D: 10})) + assertFalseNoError(xdr.Price{N: 5, D: 10}.TryEqual(xdr.Price{N: 1, D: 3})) + assertFalseNoError(xdr.Price{N: 5, D: 15}.TryEqual(xdr.Price{N: 50, D: 100})) + + // canonical using deprecated Price.Equal() method assert.True(t, xdr.Price{N: 1, D: 2}.Equal(xdr.Price{N: 1, D: 2})) assert.False(t, xdr.Price{N: 1, D: 2}.Equal(xdr.Price{N: 2, D: 3})) - // not canonical + // not canonical using deprecated Price.Equal() method assert.True(t, xdr.Price{N: 1, D: 2}.Equal(xdr.Price{N: 5, D: 10})) assert.True(t, xdr.Price{N: 5, D: 10}.Equal(xdr.Price{N: 1, D: 2})) assert.True(t, xdr.Price{N: 5, D: 10}.Equal(xdr.Price{N: 50, D: 100})) @@ -29,12 +63,24 @@ func TestPriceEqual(t *testing.T) { } func TestPriceCheaper(t *testing.T) { + assertTrueNoError, assertFalseNoError := makeAssertBoolNoError(t) + // canonical + assertTrueNoError(xdr.Price{N: 1, D: 4}.TryCheaper(xdr.Price{N: 1, D: 3})) + assertFalseNoError(xdr.Price{N: 1, D: 3}.TryCheaper(xdr.Price{N: 1, D: 4})) + assertFalseNoError(xdr.Price{N: 1, D: 4}.TryCheaper(xdr.Price{N: 1, D: 4})) + + // not canonical + assertTrueNoError(xdr.Price{N: 10, D: 40}.TryCheaper(xdr.Price{N: 3, D: 9})) + assertFalseNoError(xdr.Price{N: 3, D: 9}.TryCheaper(xdr.Price{N: 10, D: 40})) + assertFalseNoError(xdr.Price{N: 10, D: 40}.TryCheaper(xdr.Price{N: 10, D: 40})) + + // canonical using deprecated Price.Cheaper() method assert.True(t, xdr.Price{N: 1, D: 4}.Cheaper(xdr.Price{N: 1, D: 3})) assert.False(t, xdr.Price{N: 1, D: 3}.Cheaper(xdr.Price{N: 1, D: 4})) assert.False(t, xdr.Price{N: 1, D: 4}.Cheaper(xdr.Price{N: 1, D: 4})) - // not canonical + // not canonical using deprecated Price.Cheaper() method assert.True(t, xdr.Price{N: 10, D: 40}.Cheaper(xdr.Price{N: 3, D: 9})) assert.False(t, xdr.Price{N: 3, D: 9}.Cheaper(xdr.Price{N: 10, D: 40})) assert.False(t, xdr.Price{N: 10, D: 40}.Cheaper(xdr.Price{N: 10, D: 40})) @@ -43,11 +89,43 @@ func TestPriceCheaper(t *testing.T) { func TestNormalize(t *testing.T) { // canonical p := xdr.Price{N: 1, D: 4} - p.Normalize() + assert.NoError(t, p.TryNormalize()) assert.Equal(t, xdr.Price{N: 1, D: 4}, p) // not canonical p = xdr.Price{N: 500, D: 2000} + assert.NoError(t, p.TryNormalize()) + assert.Equal(t, xdr.Price{N: 1, D: 4}, p) + + // canonical using deprecated Price.Normalize() method + p = xdr.Price{N: 1, D: 4} + p.Normalize() + assert.Equal(t, xdr.Price{N: 1, D: 4}, p) + + // not canonical using deprecated Price.Normalize() method + p = xdr.Price{N: 500, D: 2000} p.Normalize() assert.Equal(t, xdr.Price{N: 1, D: 4}, p) } + +func TestInvalidPrices(t *testing.T) { + negativePrices := []xdr.Price{{N: -1, D: 4}, {N: 1, D: -4}, {N: -1, D: -4}} + zeroPrices := []xdr.Price{{N: 0, D: 4}, {N: 1, D: 0}, {N: 0, D: 0}} + + errorOnInvalid := func(p xdr.Price) { + assert.Error(t, p.Validate()) + assert.Error(t, p.TryNormalize()) + assert.Error(t, p.TryInvert()) + _, err := p.TryEqual(xdr.Price{N: 1, D: 4}) + assert.Error(t, err) + _, err = p.TryCheaper(xdr.Price{N: 1, D: 4}) + assert.Error(t, err) + } + + for _, p := range negativePrices { + errorOnInvalid(p) + } + for _, p := range zeroPrices { + errorOnInvalid(p) + } +}