Skip to content

Commit

Permalink
Add saturation test mentioned in egg's #60 discussion
Browse files Browse the repository at this point in the history
  • Loading branch information
alt-romes committed Sep 23, 2022
1 parent d7b6a8f commit 019167f
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 1 deletion.
2 changes: 1 addition & 1 deletion hegg.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ test-suite hegg-test
hs-source-dirs: test
main-is: Test.hs
other-modules: Invariants, Sym, Lambda, SimpleSym,
T1
T1, T2

other-extensions: OverloadedStrings
build-depends: base,
Expand Down
55 changes: 55 additions & 0 deletions test/T2.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE OverloadedStrings #-}
module T2 where

-- Tests whether this saturates just like mwillsey claims that it does in egg!

import Prelude hiding (not)

import Test.Tasty.HUnit
import Data.Deriving
import Data.Equality.Matching
import Data.Equality.Analysis
import Data.Equality.Language
import Data.Equality.Extraction
import Data.Equality.Saturation

data Lang a = And a a
| Or a a
| Not a
| ToElim a
| Sym Int
deriving (Functor, Foldable, Traversable)

deriveEq1 ''Lang
deriveOrd1 ''Lang
deriveShow1 ''Lang

instance Analysis Lang where
type Domain Lang = ()
makeA _ _ = ()
joinA = (<>)

instance Language Lang

x, y :: Pattern Lang
x = "x"
y = "y"
not :: Pattern Lang -> Pattern Lang
not = pat . Not

rules :: [Rewrite Lang]
rules =
[ pat (x `And` y) := not (pat (not x `Or` not y))
, pat (x `Or` y) := not (pat (not x `And` not y))
, not (not x) := pat (ToElim x)
, pat (ToElim x) := x
]

main :: IO ()
main = do
fst (equalitySaturation (Fix $ (Fix $ Not $ Fix $ Sym 0) `And` (Fix $ Not $ Fix $ Sym 1)) rules depthCost) @?= Fix (Not $ Fix $ (Fix $ Sym 0) `Or` (Fix $ Sym 1))
pure ()

2 changes: 2 additions & 0 deletions test/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import Lambda
import SimpleSym

import qualified T1
import qualified T2

tests :: TestTree
tests = testGroup "Tests"
Expand All @@ -20,6 +21,7 @@ tests = testGroup "Tests"
, simpleSymTests
, invariants
, testCase "T1" (T1.main `catch` (\(e :: SomeException) -> assertFailure (show e)))
, testCase "T2" (T2.main `catch` (\(e :: SomeException) -> assertFailure (show e)))
]

main :: IO ()
Expand Down

0 comments on commit 019167f

Please sign in to comment.