Skip to content

Commit

Permalink
Make e-graph abstract
Browse files Browse the repository at this point in the history
  • Loading branch information
alt-romes committed Aug 28, 2022
1 parent 862d55a commit f87c05f
Show file tree
Hide file tree
Showing 12 changed files with 91 additions and 78 deletions.
9 changes: 7 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Revision history for hsym
# Revision history for hegg

## 0.1.0.0 -- YYYY-mm-dd
## Unreleased

* Make e-graph abstract. The internal structure can still be modified to some
extent using Data.Equality.Graph.Lens

## 0.1.0.0 -- 2022-08-25

* First version. Released on an unsuspecting world.
1 change: 1 addition & 0 deletions hegg.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ library
-- -dsuppress-var-kinds

exposed-modules: Data.Equality.Graph,
Data.Equality.Graph.Internal,
Data.Equality.Graph.ReprUnionFind,
Data.Equality.Graph.Classes,
Data.Equality.Graph.Classes.Id,
Expand Down
2 changes: 1 addition & 1 deletion src/Data/Equality/Analysis.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import Data.Kind (Type)
import Data.Equality.Graph.Classes.Id
import Data.Equality.Graph.Nodes

import {-# SOURCE #-} Data.Equality.Graph (EGraph)
import {-# SOURCE #-} Data.Equality.Graph.Internal (EGraph)

-- | The e-class analysis defined for a language @l@.
class Eq (Domain l) => Analysis (l :: Type -> Type) where
Expand Down
7 changes: 4 additions & 3 deletions src/Data/Equality/Extraction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import qualified Data.IntMap.Strict as IM

import Data.Equality.Utils
import Data.Equality.Graph
import Data.Equality.Graph.Lens

-- vvvv and necessarily all the best sub-expressions from children equilalence classes

Expand All @@ -50,13 +51,13 @@ extractBest :: forall lang. Language lang
-> CostFunction lang -- ^ The cost function to define /best/
-> ClassId -- ^ The e-class from which we'll extract the expression
-> Fix lang -- ^ The resulting /best/ expression, in its fixed point form.
extractBest g@EGraph{classes = eclasses'} cost (flip find g -> i) =
extractBest egr cost (flip find egr -> i) =

-- Use `egg`s strategy of find costs for all possible classes and then just
-- picking up the best from the target e-class. In practice this shouldn't
-- find the cost of unused nodes because the "topmost" e-class will be the
-- target, and all sub-classes must be calculated?
let allCosts = findCosts eclasses' mempty
let allCosts = findCosts (egr^._classes) mempty

in case findBest i allCosts of
Just (CostWithExpr (_,n)) -> n
Expand Down Expand Up @@ -106,7 +107,7 @@ extractBest g@EGraph{classes = eclasses'} cost (flip find g -> i) =
-- with its cost
nodeTotalCost :: Traversable lang => ClassIdMap (CostWithExpr lang) -> ENode lang -> Maybe (CostWithExpr lang)
nodeTotalCost m (Node n) = do
expr <- traverse ((`IM.lookup` m) . flip find g) n
expr <- traverse ((`IM.lookup` m) . flip find egr) n
return $ CostWithExpr (cost ((fst . unCWE) <$> expr), (Fix $ (snd . unCWE) <$> expr))
{-# INLINE nodeTotalCost #-}

Expand Down
37 changes: 3 additions & 34 deletions src/Data/Equality/Graph.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
{-# LANGUAGE TupleSections #-}
-- {-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE UndecidableInstances #-} -- tmp show
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
Expand All @@ -15,9 +14,7 @@
module Data.Equality.Graph
(
-- * Definition of e-graph
EGraph(..)

, Memo, Worklist
EGraph

-- * Functions on e-graphs
, emptyEGraph
Expand All @@ -39,51 +36,23 @@ module Data.Equality.Graph

import Data.Function

import Data.Functor.Classes

import qualified Data.IntMap.Strict as IM
import qualified Data.Set as S

import Data.Equality.Graph.Internal
import Data.Equality.Graph.ReprUnionFind
import Data.Equality.Graph.Classes
import Data.Equality.Graph.Nodes
import Data.Equality.Analysis
import Data.Equality.Language
import Data.Equality.Graph.Lens

-- | E-graph representing terms of language @l@.
--
-- Intuitively, an e-graph is a set of equivalence classes (e-classes). Each e-class is a
-- set of e-nodes representing equivalent terms from a given language, and an e-node is a function
-- symbol paired with a list of children e-classes.
data EGraph l = EGraph
{ unionFind :: !ReprUnionFind -- ^ Union find like structure to find canonical representation of an e-class id
, classes :: !(ClassIdMap (EClass l)) -- ^ Map canonical e-class ids to their e-classes
, memo :: !(Memo l) -- ^ Hashcons maps all canonical e-nodes to their e-class ids
, worklist :: !(Worklist l) -- ^ Worklist of e-class ids that need to be upward merged
, analysisWorklist :: !(Worklist l) -- ^ Like 'worklist' but for analysis repairing
}

-- | The hashcons 𝐻 is a map from e-nodes to e-class ids
type Memo l = NodeMap l ClassId

-- | Maintained worklist of e-class ids that need to be β€œupward merged”
type Worklist l = NodeMap l ClassId

-- ROMES:TODO: join things built in paralell?
-- instance Ord1 l => Semigroup (EGraph l) where
-- (<>) eg1 eg2 = undefined -- not so easy
-- instance Ord1 l => Monoid (EGraph l) where
-- mempty = EGraph emptyUF mempty mempty mempty

instance (Show (Domain l), Show1 l) => Show (EGraph l) where
show (EGraph a b c d e) =
"UnionFind: " <> show a <>
"\n\nE-Classes: " <> show b <>
"\n\nHashcons: " <> show c <>
"\n\nWorklist: " <> show d <>
"\n\nAnalWorklist: " <> show e


-- | Add an e-node to the e-graph
--
Expand Down Expand Up @@ -234,7 +203,7 @@ merge a b egr0 =
{-# SCC merge #-}


-- | The rebuild operation processes the e-graph's current 'Worklist',
-- | The rebuild operation processes the e-graph's current worklist,
-- restoring the invariants of deduplication and congruence. Rebuilding is
-- similar to other approaches in how it restores congruence; but it uniquely
-- allows the client to choose when to restore invariants in the context of a
Expand Down
22 changes: 0 additions & 22 deletions src/Data/Equality/Graph.hs-boot

This file was deleted.

41 changes: 41 additions & 0 deletions src/Data/Equality/Graph/Internal.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{-# LANGUAGE UndecidableInstances #-} -- tmp show
{-# LANGUAGE FlexibleContexts #-}
{-# OPTIONS_HADDOCK hide #-}
{-|
Non-abstract definition of e-graphs
-}
module Data.Equality.Graph.Internal where

import Data.Functor.Classes

import Data.Equality.Graph.ReprUnionFind
import Data.Equality.Graph.Classes
import Data.Equality.Graph.Nodes
import Data.Equality.Analysis

-- | E-graph representing terms of language @l@.
--
-- Intuitively, an e-graph is a set of equivalence classes (e-classes). Each e-class is a
-- set of e-nodes representing equivalent terms from a given language, and an e-node is a function
-- symbol paired with a list of children e-classes.
data EGraph l = EGraph
{ unionFind :: !ReprUnionFind -- ^ Union find like structure to find canonical representation of an e-class id
, classes :: !(ClassIdMap (EClass l)) -- ^ Map canonical e-class ids to their e-classes
, memo :: !(Memo l) -- ^ Hashcons maps all canonical e-nodes to their e-class ids
, worklist :: !(Worklist l) -- ^ Worklist of e-class ids that need to be upward merged
, analysisWorklist :: !(Worklist l) -- ^ Like 'worklist' but for analysis repairing
}

-- | The hashcons 𝐻 is a map from e-nodes to e-class ids
type Memo l = NodeMap l ClassId

-- | Maintained worklist of e-class ids that need to be β€œupward merged”
type Worklist l = NodeMap l ClassId

instance (Show (Domain l), Show1 l) => Show (EGraph l) where
show (EGraph a b c d e) =
"UnionFind: " <> show a <>
"\n\nE-Classes: " <> show b <>
"\n\nHashcons: " <> show c <>
"\n\nWorklist: " <> show d <>
"\n\nAnalWorklist: " <> show e
9 changes: 9 additions & 0 deletions src/Data/Equality/Graph/Internal.hs-boot
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE StandaloneKindSignatures #-}
module Data.Equality.Graph.Internal where

import Data.Kind

type EGraph :: (Type -> Type) -> Type
type role EGraph nominal
data EGraph l
16 changes: 10 additions & 6 deletions src/Data/Equality/Graph/Lens.hs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE Rank2Types #-}
{-|
Hand-rolled lenses on e-graphs and e-classes which come in quite handy and
are heavily used in 'Data.Equality.Graph'.
Hand-rolled lenses on e-graphs and e-classes which come in quite handy, are
heavily used in 'Data.Equality.Graph', and are the only exported way of
editing the structure of the e-graph. If you want to write some complex
'Analysis' you'll probably need these.
-}
module Data.Equality.Graph.Lens where

Expand All @@ -12,11 +14,12 @@ import qualified Data.Set as S
import Data.Functor.Identity
import Data.Functor.Const

import Data.Equality.Graph.Internal
import Data.Equality.Graph.Classes.Id
import Data.Equality.Graph.Nodes
import Data.Equality.Graph.Classes
import Data.Equality.Graph.ReprUnionFind
import Data.Equality.Analysis
import {-# SOURCE #-} Data.Equality.Graph (EGraph(..), Memo, find)

-- | A 'Lens'' as defined in other lenses libraries
type Lens' s a = forall f. Functor f => (a -> f a) -> (s -> f s)
Expand All @@ -37,12 +40,13 @@ type Lens' s a = forall f. Functor f => (a -> f a) -> (s -> f s)
-- Calls 'error' when the e-class doesn't exist
_class :: ClassId -> Lens' (EGraph l) (EClass l)
_class i afa s =
let canon_id = find i s
let canon_id = findRepr i (unionFind s)
in (\c' -> s { classes = IM.insert canon_id c' (classes s) }) <$> afa (classes s IM.! canon_id)
{-# INLINE _classΒ #-}

-- | Lens for the 'Memo' of e-nodes in an e-graph
_memo :: Lens' (EGraph l) (Memo l)
-- | Lens for the memo of e-nodes in an e-graph, that is, a mapping from
-- e-nodes to the e-class they're represented in
_memo :: Lens' (EGraph l) (NodeMap l ClassId)
_memo afa egr = (\m1 -> egr {memo = m1}) <$> afa (memo egr)
{-# INLINE _memoΒ #-}

Expand Down
3 changes: 2 additions & 1 deletion src/Data/Equality/Matching.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import qualified Data.IntMap.Strict as IM
import qualified Data.IntSet as IS

import Data.Equality.Graph
import Data.Equality.Graph.Lens
import Data.Equality.Matching.Database
import Data.Equality.Matching.Pattern

Expand Down Expand Up @@ -69,7 +70,7 @@ ematch db patr =

-- | Convert an e-graph into a database
eGraphToDatabase :: Language l => EGraph l -> Database l
eGraphToDatabase EGraph{..} = foldrWithKeyNM' addENodeToDB (DB mempty) memo
eGraphToDatabase egr = foldrWithKeyNM' addENodeToDB (DB mempty) (egr^._memo)
where

-- Add an enode in an e-graph, given its class, to a database
Expand Down
11 changes: 7 additions & 4 deletions src/Data/Equality/Saturation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ import Control.Monad
import Data.Proxy

import Data.Equality.Utils
import Data.Equality.Graph.Nodes
import Data.Equality.Graph.Lens
import qualified Data.Equality.Graph as G
import Data.Equality.Graph.Monad
import Data.Equality.Language
Expand Down Expand Up @@ -102,9 +104,10 @@ equalitySaturation' _ expr rewrites cost = egraph $ do
equalitySaturation'' 30 _ = return () -- Stop after X iterations
equalitySaturation'' i stats = do

egr@G.EGraph{ G.memo = beforeMemo, G.classes = beforeClasses } <- get
egr <- get

let db = eGraphToDatabase egr
let (beforeMemo, beforeClasses) = (egr^._memo, egr^._classes)
db = eGraphToDatabase egr

-- Read-only phase, invariants are preserved
-- With backoff scheduler
Expand All @@ -117,7 +120,7 @@ equalitySaturation' _ expr rewrites cost = egraph $ do
-- Restore the invariants once per iteration
rebuild

G.EGraph { G.memo = afterMemo, G.classes = afterClasses } <- get
(afterMemo, afterClasses) <- gets (\g -> (g^._memo, g^._classes))

-- ROMES:TODO: Node limit...
-- ROMES:TODO: Actual Timeout... not just iteration timeout
Expand Down Expand Up @@ -177,7 +180,7 @@ equalitySaturation' _ expr rewrites cost = egraph $ do

-- | Represent a pattern in the e-graph a pattern given substitions
reprPat :: Subst -> l (Pattern l) -> EGraphM l ClassId
reprPat subst = add . G.Node <=< traverse \case
reprPat subst = add . Node <=< traverse \case
VariablePattern v ->
case IM.lookup v subst of
Nothing -> error "impossible: couldn't find v in subst?"
Expand Down
11 changes: 6 additions & 5 deletions test/Invariants.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import qualified Data.Set as S
import qualified Data.IntMap.Strict as IM

import Data.Equality.Graph.Monad as GM
import Data.Equality.Graph.Lens
import Data.Equality.Graph
import Data.Equality.Analysis
import Data.Equality.Extraction
Expand Down Expand Up @@ -52,7 +53,7 @@ instance Analysis SimpleExpr where
patFoldAllClasses :: forall l. (Language l, Num (Pattern l))
=> Fix l -> Integer -> Bool
patFoldAllClasses expr i =
case IM.toList $ classes eg of
case IM.toList $ (eg^._classes) of
[_] -> True
_ -> False
where
Expand Down Expand Up @@ -97,7 +98,7 @@ ematchSingletonVar v eg =
let
db = eGraphToDatabase eg
matches = S.fromList $ map matchClassId $ ematch db (VariablePattern v)
eclasses = S.fromList $ map fst $ IM.toList $ classes eg
eclasses = S.fromList $ map fst $ IM.toList (eg^._classes)
in
matches == eclasses

Expand All @@ -122,13 +123,13 @@ ematchSingletonVar v eg =
-- ROMES:TODO Should I rebuild it here? Then the property test is that after rebuilding ...HashConsInvariant
hashConsInvariant :: forall l. Language l
=> EGraph l -> Bool
hashConsInvariant eg@EGraph{..} =
all f (IM.toList classes)
hashConsInvariant eg =
all f (IM.toList (eg^._classes))
where
-- e-node 𝑛 ∈ 𝑀 [π‘Ž] ⇐⇒ 𝐻 [canonicalize(𝑛)] = find(π‘Ž)
f (i, EClass _ nodes _ _) = all g nodes
where
g en = case lookupNM (canonicalize en eg) memo of
g en = case lookupNM (canonicalize en eg) (eg^._memo) of
Nothing -> error "how can we not find canonical thing in map? :)" -- False
Just i' -> i' == find i eg

Expand Down

0 comments on commit f87c05f

Please sign in to comment.