Skip to content

Commit 971bae0

Browse files
authored
Fix (**) and use property tests (#57)
The default implementation of (**) relied on log and incorrectly handled some inputs. The fix is making an explicit implementation using the pow function. To test the changes, tests for functions from the Floating typeclass are re-written using property tests. There are a few helper functions to make writing the actual properties easy. More tests can be converted to properties, but this is left for another PR.
1 parent 90812e0 commit 971bae0

File tree

3 files changed

+122
-48
lines changed

3 files changed

+122
-48
lines changed

arrayfire.cabal

+1
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ test-suite test
151151
base < 5,
152152
directory,
153153
hspec,
154+
HUnit,
154155
QuickCheck,
155156
quickcheck-classes,
156157
vector

src/ArrayFire/Orphans.hs

+4
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,12 @@ instance forall a . (Ord a, AFType a, Fractional a) => Floating (Array a) where
5050
pi = A.scalar @a 3.14159
5151
exp = A.exp @a
5252
log = A.log @a
53+
sqrt = A.sqrt @a
54+
(**) = A.pow @a
5355
sin = A.sin @a
5456
cos = A.cos @a
57+
tan = A.tan @a
58+
tanh = A.tanh @a
5559
asin = A.asin @a
5660
acos = A.acos @a
5761
atan = A.atan @a

test/ArrayFire/ArithSpec.hs

+117-48
Original file line numberDiff line numberDiff line change
@@ -1,99 +1,168 @@
1+
{-# LANGUAGE RankNTypes #-}
2+
{-# LANGUAGE ScopedTypeVariables #-}
13
{-# LANGUAGE TypeApplications #-}
4+
25
module ArrayFire.ArithSpec where
36

4-
import ArrayFire hiding (acos)
5-
import Prelude hiding (abs, sqrt, div, and, or, not, isNaN)
6-
import Test.Hspec
7+
import ArrayFire (AFType, Array, cast, clamp, getType, isInf, isZero, matrix, maxOf, minOf, mkArray, scalar, vector)
8+
import qualified ArrayFire
9+
import Control.Exception (throwIO)
10+
import Control.Monad (unless, when)
711
import Foreign.C
12+
import GHC.Exts (IsList (..))
13+
import GHC.Stack
14+
import Test.HUnit.Lang (FailureReason (..), HUnitFailure (..))
15+
import Test.Hspec
16+
import Test.Hspec.QuickCheck
17+
import Prelude hiding (div)
18+
19+
compareWith :: (HasCallStack, Show a) => (a -> a -> Bool) -> a -> a -> Expectation
20+
compareWith comparator result expected =
21+
unless (comparator result expected) $ do
22+
throwIO (HUnitFailure location $ ExpectedButGot Nothing expectedMsg actualMsg)
23+
where
24+
expectedMsg = show expected
25+
actualMsg = show result
26+
location = case reverse (toList callStack) of
27+
(_, loc) : _ -> Just loc
28+
[] -> Nothing
29+
30+
class (Num a) => HasEpsilon a where
31+
eps :: a
32+
33+
instance HasEpsilon Float where
34+
eps = 1.1920929e-7
35+
36+
instance HasEpsilon Double where
37+
eps = 2.220446049250313e-16
38+
39+
approxWith :: (Ord a, Num a) => a -> a -> a -> a -> Bool
40+
approxWith rtol atol a b = abs (a - b) <= Prelude.max atol (rtol * Prelude.max (abs a) (abs b))
41+
42+
approx :: (Ord a, HasEpsilon a) => a -> a -> Bool
43+
approx a b = approxWith (2 * eps * Prelude.max (abs a) (abs b)) (4 * eps) a b
44+
45+
shouldBeApprox :: (Ord a, HasEpsilon a, Show a) => a -> a -> Expectation
46+
shouldBeApprox = compareWith approx
47+
48+
evalf :: (AFType a) => Array a -> a
49+
evalf = ArrayFire.getScalar
50+
51+
shouldMatchBuiltin ::
52+
(AFType a, Ord a, RealFloat a, HasEpsilon a, Show a) =>
53+
(Array a -> Array a) ->
54+
(a -> a) ->
55+
a ->
56+
Expectation
57+
shouldMatchBuiltin f f' x
58+
| isInfinite y && isInfinite y' = pure ()
59+
| Prelude.isNaN y && Prelude.isNaN y' = pure ()
60+
| otherwise = y `shouldBeApprox` y'
61+
where
62+
y = evalf (f (scalar x))
63+
y' = f' x
64+
65+
shouldMatchBuiltin2 ::
66+
(AFType a, Ord a, RealFloat a, HasEpsilon a, Show a) =>
67+
(Array a -> Array a -> Array a) ->
68+
(a -> a -> a) ->
69+
a ->
70+
a ->
71+
Expectation
72+
shouldMatchBuiltin2 f f' a = shouldMatchBuiltin (f (scalar a)) (f' a)
873

974
spec :: Spec
1075
spec =
1176
describe "Arith tests" $ do
1277
it "Should negate scalar value" $ do
1378
negate (scalar @Int 1) `shouldBe` (-1)
1479
it "Should negate a vector" $ do
15-
negate (vector @Int 3 [2,2,2]) `shouldBe` vector @Int 3 [-2,-2,-2]
80+
negate (vector @Int 3 [2, 2, 2]) `shouldBe` vector @Int 3 [-2, -2, -2]
1681
it "Should add two scalar arrays" $ do
1782
scalar @Int 1 + 2 `shouldBe` 3
1883
it "Should add two scalar bool arrays" $ do
1984
scalar @CBool 1 + 0 `shouldBe` 1
2085
it "Should subtract two scalar arrays" $ do
2186
scalar @Int 4 - 2 `shouldBe` 2
2287
it "Should multiply two scalar arrays" $ do
23-
scalar @Double 4 `mul` 2 `shouldBe` 8
88+
scalar @Double 4 `ArrayFire.mul` 2 `shouldBe` 8
2489
it "Should divide two scalar arrays" $ do
25-
div @Double 8 2 `shouldBe` 4
90+
ArrayFire.div @Double 8 2 `shouldBe` 4
2691
it "Should add two matrices" $ do
27-
matrix @Int (2,2) [[1,1],[1,1]] + matrix @Int (2,2) [[1,1],[1,1]]
28-
`shouldBe`
29-
matrix @Int (2,2) [[2,2],[2,2]]
30-
-- Exact comparisons of Double don't make sense here, so we just check that the result is
31-
-- accurate up to some epsilon.
32-
it "Should take cubed root" $ do
33-
allTrueAll ((abs (3 - cbrt @Double 27)) `lt` 1.0e-14) `shouldBe` (1, 0)
34-
it "Should take square root" $ do
35-
allTrueAll ((abs (2 - sqrt @Double 4)) `lt` 1.0e-14) `shouldBe` (1, 0)
92+
matrix @Int (2, 2) [[1, 1], [1, 1]] + matrix @Int (2, 2) [[1, 1], [1, 1]]
93+
`shouldBe` matrix @Int (2, 2) [[2, 2], [2, 2]]
94+
prop "Should take cubed root" $ \(x :: Double) ->
95+
evalf (ArrayFire.cbrt (scalar (x * x * x))) `shouldBeApprox` x
3696

3797
it "Should lte Array" $ do
38-
2 `le` (3 :: Array Double) `shouldBe` 1
98+
2 `ArrayFire.le` (3 :: Array Double) `shouldBe` 1
3999
it "Should gte Array" $ do
40-
2 `ge` (3 :: Array Double) `shouldBe` 0
100+
2 `ArrayFire.ge` (3 :: Array Double) `shouldBe` 0
41101
it "Should gt Array" $ do
42-
2 `gt` (3 :: Array Double) `shouldBe` 0
102+
2 `ArrayFire.gt` (3 :: Array Double) `shouldBe` 0
43103
it "Should lt Array" $ do
44-
2 `le` (3 :: Array Double) `shouldBe` 1
104+
2 `ArrayFire.le` (3 :: Array Double) `shouldBe` 1
45105
it "Should eq Array" $ do
46106
3 == (3 :: Array Double) `shouldBe` True
47107
it "Should and Array" $ do
48-
(mkArray @CBool [1] [0] `and` mkArray [1] [1])
49-
`shouldBe` mkArray [1] [0]
108+
(mkArray @CBool [1] [0] `ArrayFire.and` mkArray [1] [1])
109+
`shouldBe` mkArray [1] [0]
50110
it "Should and Array" $ do
51-
(mkArray @CBool [2] [0,0] `and` mkArray [2] [1,0])
52-
`shouldBe` mkArray [2] [0, 0]
111+
(mkArray @CBool [2] [0, 0] `ArrayFire.and` mkArray [2] [1, 0])
112+
`shouldBe` mkArray [2] [0, 0]
53113
it "Should or Array" $ do
54-
(mkArray @CBool [2] [0,0] `or` mkArray [2] [1,0])
55-
`shouldBe` mkArray [2] [1, 0]
114+
(mkArray @CBool [2] [0, 0] `ArrayFire.or` mkArray [2] [1, 0])
115+
`shouldBe` mkArray [2] [1, 0]
56116
it "Should not Array" $ do
57-
not (mkArray @CBool [2] [1,0]) `shouldBe` mkArray [2] [0,1]
117+
ArrayFire.not (mkArray @CBool [2] [1, 0]) `shouldBe` mkArray [2] [0, 1]
58118
it "Should bitwise and array" $ do
59-
bitAnd (scalar @Int 1) (scalar @Int 0)
60-
`shouldBe`
61-
0
119+
ArrayFire.bitAnd (scalar @Int 1) (scalar @Int 0)
120+
`shouldBe` 0
62121
it "Should bitwise or array" $ do
63-
bitOr (scalar @Int 1) (scalar @Int 0)
64-
`shouldBe`
65-
1
122+
ArrayFire.bitOr (scalar @Int 1) (scalar @Int 0)
123+
`shouldBe` 1
66124
it "Should bitwise xor array" $ do
67-
bitXor (scalar @Int 1) (scalar @Int 1)
68-
`shouldBe`
69-
0
125+
ArrayFire.bitXor (scalar @Int 1) (scalar @Int 1)
126+
`shouldBe` 0
70127
it "Should bitwise shift left an array" $ do
71-
bitShiftL (scalar @Int 1) (scalar @Int 3)
72-
`shouldBe`
73-
8
128+
ArrayFire.bitShiftL (scalar @Int 1) (scalar @Int 3)
129+
`shouldBe` 8
74130
it "Should cast an array" $ do
75131
getType (cast (scalar @Int 1) :: Array Double)
76-
`shouldBe`
77-
F64
132+
`shouldBe` ArrayFire.F64
78133
it "Should find the minimum of two arrays" $ do
79134
minOf (scalar @Int 1) (scalar @Int 0)
80-
`shouldBe`
81-
0
135+
`shouldBe` 0
82136
it "Should find the max of two arrays" $ do
83137
maxOf (scalar @Int 1) (scalar @Int 0)
84-
`shouldBe`
85-
1
138+
`shouldBe` 1
86139
it "Should take the clamp of 3 arrays" $ do
87140
clamp (scalar @Int 2) (scalar @Int 1) (scalar @Int 3)
88-
`shouldBe`
89-
2
141+
`shouldBe` 2
90142
it "Should check if an array has positive or negative infinities" $ do
91143
isInf (scalar @Double (1 / 0)) `shouldBe` scalar @Double 1
92144
isInf (scalar @Double 10) `shouldBe` scalar @Double 0
93145
it "Should check if an array has any NaN values" $ do
94-
isNaN (scalar @Double (acos 2)) `shouldBe` scalar @Double 1
95-
isNaN (scalar @Double 10) `shouldBe` scalar @Double 0
146+
ArrayFire.isNaN (scalar @Double (acos 2)) `shouldBe` scalar @Double 1
147+
ArrayFire.isNaN (scalar @Double 10) `shouldBe` scalar @Double 0
96148
it "Should check if an array has any Zero values" $ do
97149
isZero (scalar @Double (acos 2)) `shouldBe` scalar @Double 0
98150
isZero (scalar @Double 0) `shouldBe` scalar @Double 1
99151
isZero (scalar @Double 1) `shouldBe` scalar @Double 0
152+
153+
prop "Floating @Float (exp)" $ \(x :: Float) -> exp `shouldMatchBuiltin` exp $ x
154+
prop "Floating @Float (log)" $ \(x :: Float) -> log `shouldMatchBuiltin` log $ x
155+
prop "Floating @Float (sqrt)" $ \(x :: Float) -> sqrt `shouldMatchBuiltin` sqrt $ x
156+
prop "Floating @Float (**)" $ \(x :: Float) (y :: Float) -> ((**) `shouldMatchBuiltin2` (**)) x y
157+
prop "Floating @Float (sin)" $ \(x :: Float) -> sin `shouldMatchBuiltin` sin $ x
158+
prop "Floating @Float (cos)" $ \(x :: Float) -> cos `shouldMatchBuiltin` cos $ x
159+
prop "Floating @Float (tan)" $ \(x :: Float) -> tan `shouldMatchBuiltin` tan $ x
160+
prop "Floating @Float (asin)" $ \(x :: Float) -> asin `shouldMatchBuiltin` asin $ x
161+
prop "Floating @Float (acos)" $ \(x :: Float) -> acos `shouldMatchBuiltin` acos $ x
162+
prop "Floating @Float (atan)" $ \(x :: Float) -> atan `shouldMatchBuiltin` atan $ x
163+
prop "Floating @Float (sinh)" $ \(x :: Float) -> sinh `shouldMatchBuiltin` sinh $ x
164+
prop "Floating @Float (cosh)" $ \(x :: Float) -> cosh `shouldMatchBuiltin` cosh $ x
165+
prop "Floating @Float (tanh)" $ \(x :: Float) -> tanh `shouldMatchBuiltin` tanh $ x
166+
prop "Floating @Float (asinh)" $ \(x :: Float) -> asinh `shouldMatchBuiltin` asinh $ x
167+
prop "Floating @Float (acosh)" $ \(x :: Float) -> acosh `shouldMatchBuiltin` acosh $ x
168+
prop "Floating @Float (atanh)" $ \(x :: Float) -> atanh `shouldMatchBuiltin` atanh $ x

0 commit comments

Comments
 (0)