Skip to content
8 changes: 8 additions & 0 deletions price/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ var (
ErrDivisionByZero = errors.New("division by 0")
// ErrOverflow is returned when a price operation would result in an integer overflow
ErrOverflow = errors.New("overflow")
// ErrNoNegatives is returned when a price operation is given a negative number
ErrNoNegatives = errors.New("negative numbers are not allowed")
)

// Parse calculates and returns the best rational approximation of the given
Expand Down Expand Up @@ -181,6 +183,9 @@ func MulFractionRoundDown(x int64, n int64, d int64) (int64, error) {
if d == 0 {
return 0, ErrDivisionByZero
}
if x < 0 || n < 0 || d < 0 {
return 0, ErrNoNegatives
}

hi, lo := bits.Mul64(uint64(x), uint64(n))

Expand All @@ -202,6 +207,9 @@ func mulFractionRoundUp(x int64, n int64, d int64) (int64, error) {
if d == 0 {
return 0, ErrDivisionByZero
}
if x < 0 || n < 0 || d < 0 {
return 0, ErrNoNegatives
}

hi, lo := bits.Mul64(uint64(x), uint64(n))
lo, carry := bits.Add64(lo, uint64(d-1), 0)
Expand Down
100 changes: 91 additions & 9 deletions xdr/price.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,46 @@
package xdr

import (
"fmt"
"math/big"
)

// String returns a string representation of `p`
// String returns a string representation of `p` if `p` is valid,
// and string indicating the price's invalidity otherwise.
// Satisfies the fmt.Stringer interface.
func (p Price) String() string {
if err := p.Validate(); err != nil {
return fmt.Sprintf("<invalid price (%d/%d): %v>", p.N, p.D, err)
}
return big.NewRat(int64(p.N), int64(p.D)).FloatString(7)
}

// Equal returns whether the price's value is the same,
// TryEqual returns whether the price's value is the same,
// taking into account denormalized representation
// (e.g. Price{1, 2}.EqualValue(Price{2,4}) == true )
func (p Price) Equal(q Price) bool {
// See the Cheaper() method for the reasoning behind this:
return uint64(p.N)*uint64(q.D) == uint64(q.N)*uint64(p.D)
// Returns an error if either price is invalid.
func (p Price) TryEqual(q Price) (bool, error) {
if err := p.Validate(); err != nil {
return false, fmt.Errorf("invalid price p: %w", err)
}
if err := q.Validate(); err != nil {
return false, fmt.Errorf("invalid price q: %w", err)
}
// See the TryCheaper() method for the reasoning behind this:
return uint64(p.N)*uint64(q.D) == uint64(q.N)*uint64(p.D), nil

Check failure on line 30 in xdr/price.go

View workflow job for this annotation

GitHub Actions / golangci

G115: integer overflow conversion int32 -> uint64 (gosec)
}

// Cheaper indicates if the Price's value is lower,
// taking into account denormalized representation
// TryCheaper indicates if the Price's value is lower,
// taking into account denormalized representation.
// (e.g. Price{1, 2}.Cheaper(Price{2,4}) == false )
func (p Price) Cheaper(q Price) bool {
// Returns an error if either price is invalid
func (p Price) TryCheaper(q Price) (bool, error) {
if err := p.Validate(); err != nil {
return false, fmt.Errorf("invalid price p: %w", err)
}
if err := q.Validate(); err != nil {
return false, fmt.Errorf("invalid price q: %w", err)
}
// To avoid float precision issues when naively comparing Price.N/Price.D,
// we use the cross product instead:
//
Expand All @@ -31,17 +51,79 @@
// (p.N / p.D) * (p.D * q.D) < (q.N / q.D) * (p.D * q.D)
// <==>
// p.N * q.D < q.N * p.D
return uint64(p.N)*uint64(q.D) < uint64(q.N)*uint64(p.D), nil

Check failure on line 54 in xdr/price.go

View workflow job for this annotation

GitHub Actions / golangci

G115: integer overflow conversion int32 -> uint64 (gosec)
}

// TryNormalize sets the price to its rational canonical form.
// Returns an error if the price is invalid
func (p *Price) TryNormalize() error {
if err := p.Validate(); err != nil {
return fmt.Errorf("invalid price: %w", err)
}
r := big.NewRat(int64(p.N), int64(p.D))
p.N = Int32(r.Num().Int64())

Check failure on line 64 in xdr/price.go

View workflow job for this annotation

GitHub Actions / golangci

G115: integer overflow conversion int64 -> int32 (gosec)
p.D = Int32(r.Denom().Int64())

Check failure on line 65 in xdr/price.go

View workflow job for this annotation

GitHub Actions / golangci

G115: integer overflow conversion int64 -> int32 (gosec)
return nil
}

// TryInvert inverts Price.
// Returns an error if the price is invalid
func (p *Price) TryInvert() error {
if err := p.Validate(); err != nil {
return fmt.Errorf("invalid price: %w", err)
}
p.N, p.D = p.D, p.N
return nil
}

// Validate checks if the price is valid and returns an error if not.
func (p Price) Validate() error {
if p.N == 0 {
return fmt.Errorf("price cannot be 0: %d/%d", p.N, p.D)
}
if p.D == 0 {
return fmt.Errorf("price denominator cannot be 0: %d/%d", p.N, p.D)
}
if p.N < 0 || p.D < 0 {
return fmt.Errorf("price cannot be negative: %d/%d", p.N, p.D)
}
return nil
}

// Equal returns whether the price's value is the same,
// taking into account denormalized representation
// (e.g. Price{1, 2}.EqualValue(Price{2,4}) == true ).
// It does not validate the prices and may produce incorrect results for invalid inputs.
//
// Deprecated: Use TryEqual instead, which returns an error for invalid prices.
func (p Price) Equal(q Price) bool {
return uint64(p.N)*uint64(q.D) == uint64(q.N)*uint64(p.D)

Check failure on line 100 in xdr/price.go

View workflow job for this annotation

GitHub Actions / golangci

G115: integer overflow conversion int32 -> uint64 (gosec)
}

// Cheaper indicates if the Price's value is lower,
// taking into account denormalized representation
// (e.g. Price{1, 2}.Cheaper(Price{2,4}) == false ).
// It does not validate the prices and may produce incorrect results for invalid inputs.
//
// Deprecated: Use TryCheaper instead, which returns an error for invalid prices.
func (p Price) Cheaper(q Price) bool {
return uint64(p.N)*uint64(q.D) < uint64(q.N)*uint64(p.D)
}

// Normalize sets Price to its rational canonical form
// Normalize sets Price to its rational canonical form.
// It panics if the price denominator is zero.
//
// Deprecated: Use TryNormalize instead, which returns an error for invalid prices.
func (p *Price) Normalize() {
r := big.NewRat(int64(p.N), int64(p.D))
p.N = Int32(r.Num().Int64())
p.D = Int32(r.Denom().Int64())
}

// Invert inverts Price.
// It may set a Price with zero denominator if the original price's numerator is zero.
//
// Deprecated: Use TryInvert instead, which returns an error for invalid prices.
func (p *Price) Invert() {
p.N, p.D = p.D, p.N
}
84 changes: 81 additions & 3 deletions xdr/price_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,52 @@ import (
"github.com/stellar/go-stellar-sdk/xdr"
)

func makeAssertBoolNoError(t *testing.T) (func(bool, error), func(bool, error)) {
t.Helper()
assertTrueNoError := func(res bool, err error) {
t.Helper()
assert.NoError(t, err)
assert.True(t, res)
}
assertFalseNoError := func(res bool, err error) {
t.Helper()
assert.NoError(t, err)
assert.False(t, res)
}
return assertTrueNoError, assertFalseNoError
}

func TestPriceInvert(t *testing.T) {
p := xdr.Price{N: 1, D: 2}
assert.NoError(t, p.TryInvert())
assert.Equal(t, xdr.Price{N: 2, D: 1}, p)

// Using deprecated Price.Invert() method
p = xdr.Price{N: 1, D: 2}
p.Invert()
assert.Equal(t, xdr.Price{N: 2, D: 1}, p)
}

func TestPriceEqual(t *testing.T) {
assertTrueNoError, assertFalseNoError := makeAssertBoolNoError(t)

// canonical
assertTrueNoError(xdr.Price{N: 1, D: 2}.TryEqual(xdr.Price{N: 1, D: 2}))
assertFalseNoError(xdr.Price{N: 1, D: 2}.TryEqual(xdr.Price{N: 2, D: 3}))

// not canonical
assertTrueNoError(xdr.Price{N: 1, D: 2}.TryEqual(xdr.Price{N: 5, D: 10}))
assertTrueNoError(xdr.Price{N: 5, D: 10}.TryEqual(xdr.Price{N: 1, D: 2}))
assertTrueNoError(xdr.Price{N: 5, D: 10}.TryEqual(xdr.Price{N: 50, D: 100}))
assertFalseNoError(xdr.Price{N: 1, D: 3}.TryEqual(xdr.Price{N: 5, D: 10}))
assertFalseNoError(xdr.Price{N: 5, D: 10}.TryEqual(xdr.Price{N: 1, D: 3}))
assertFalseNoError(xdr.Price{N: 5, D: 15}.TryEqual(xdr.Price{N: 50, D: 100}))

// canonical using deprecated Price.Equal() method
assert.True(t, xdr.Price{N: 1, D: 2}.Equal(xdr.Price{N: 1, D: 2}))
assert.False(t, xdr.Price{N: 1, D: 2}.Equal(xdr.Price{N: 2, D: 3}))

// not canonical
// not canonical using deprecated Price.Equal() method
assert.True(t, xdr.Price{N: 1, D: 2}.Equal(xdr.Price{N: 5, D: 10}))
assert.True(t, xdr.Price{N: 5, D: 10}.Equal(xdr.Price{N: 1, D: 2}))
assert.True(t, xdr.Price{N: 5, D: 10}.Equal(xdr.Price{N: 50, D: 100}))
Expand All @@ -29,12 +63,24 @@ func TestPriceEqual(t *testing.T) {
}

func TestPriceCheaper(t *testing.T) {
assertTrueNoError, assertFalseNoError := makeAssertBoolNoError(t)

// canonical
assertTrueNoError(xdr.Price{N: 1, D: 4}.TryCheaper(xdr.Price{N: 1, D: 3}))
assertFalseNoError(xdr.Price{N: 1, D: 3}.TryCheaper(xdr.Price{N: 1, D: 4}))
assertFalseNoError(xdr.Price{N: 1, D: 4}.TryCheaper(xdr.Price{N: 1, D: 4}))

// not canonical
assertTrueNoError(xdr.Price{N: 10, D: 40}.TryCheaper(xdr.Price{N: 3, D: 9}))
assertFalseNoError(xdr.Price{N: 3, D: 9}.TryCheaper(xdr.Price{N: 10, D: 40}))
assertFalseNoError(xdr.Price{N: 10, D: 40}.TryCheaper(xdr.Price{N: 10, D: 40}))

// canonical using deprecated Price.Cheaper() method
assert.True(t, xdr.Price{N: 1, D: 4}.Cheaper(xdr.Price{N: 1, D: 3}))
assert.False(t, xdr.Price{N: 1, D: 3}.Cheaper(xdr.Price{N: 1, D: 4}))
assert.False(t, xdr.Price{N: 1, D: 4}.Cheaper(xdr.Price{N: 1, D: 4}))

// not canonical
// not canonical using deprecated Price.Cheaper() method
assert.True(t, xdr.Price{N: 10, D: 40}.Cheaper(xdr.Price{N: 3, D: 9}))
assert.False(t, xdr.Price{N: 3, D: 9}.Cheaper(xdr.Price{N: 10, D: 40}))
assert.False(t, xdr.Price{N: 10, D: 40}.Cheaper(xdr.Price{N: 10, D: 40}))
Expand All @@ -43,11 +89,43 @@ func TestPriceCheaper(t *testing.T) {
func TestNormalize(t *testing.T) {
// canonical
p := xdr.Price{N: 1, D: 4}
p.Normalize()
assert.NoError(t, p.TryNormalize())
assert.Equal(t, xdr.Price{N: 1, D: 4}, p)

// not canonical
p = xdr.Price{N: 500, D: 2000}
assert.NoError(t, p.TryNormalize())
assert.Equal(t, xdr.Price{N: 1, D: 4}, p)

// canonical using deprecated Price.Normalize() method
p = xdr.Price{N: 1, D: 4}
p.Normalize()
assert.Equal(t, xdr.Price{N: 1, D: 4}, p)

// not canonical using deprecated Price.Normalize() method
p = xdr.Price{N: 500, D: 2000}
p.Normalize()
assert.Equal(t, xdr.Price{N: 1, D: 4}, p)
}

func TestInvalidPrices(t *testing.T) {
negativePrices := []xdr.Price{{N: -1, D: 4}, {N: 1, D: -4}, {N: -1, D: -4}}
zeroPrices := []xdr.Price{{N: 0, D: 4}, {N: 1, D: 0}, {N: 0, D: 0}}

errorOnInvalid := func(p xdr.Price) {
assert.Error(t, p.Validate())
assert.Error(t, p.TryNormalize())
assert.Error(t, p.TryInvert())
_, err := p.TryEqual(xdr.Price{N: 1, D: 4})
assert.Error(t, err)
_, err = p.TryCheaper(xdr.Price{N: 1, D: 4})
assert.Error(t, err)
}

for _, p := range negativePrices {
errorOnInvalid(p)
}
for _, p := range zeroPrices {
errorOnInvalid(p)
}
}
Loading