diff --git a/src/Data/Equality/Extraction.hs b/src/Data/Equality/Extraction.hs index 2610cd2..28b368a 100644 --- a/src/Data/Equality/Extraction.hs +++ b/src/Data/Equality/Extraction.hs @@ -51,7 +51,7 @@ extractBest :: forall lang cost -> CostFunction lang cost -- ^ The cost function to define /best/ -> ClassId -- ^ The e-class from which we'll extract the expression -> Fix lang -- ^ The resulting /best/ expression, in its fixed point form. -extractBest egr cost (flip find egr -> i) = +extractBest (rebuild -> egr) cost (flip unsafeFind egr -> i) = -- Use `egg`s strategy of find costs for all possible classes and then just -- picking up the best from the target e-class. In practice this shouldn't @@ -106,7 +106,7 @@ extractBest egr cost (flip find egr -> i) = -- with its cost nodeTotalCost :: Traversable lang => ClassIdMap (CostWithExpr lang cost) -> ENode lang -> Maybe (CostWithExpr lang cost) nodeTotalCost m (Node n) = do - expr <- traverse ((`IM.lookup` m) . flip find egr) n + expr <- traverse ((`IM.lookup` m) . flip unsafeFind egr) n return $ CostWithExpr (cost ((fst . unCWE) <$> expr), (Fix $ (snd . unCWE) <$> expr)) {-# INLINE nodeTotalCost #-} {-# INLINABLE extractBest #-} diff --git a/src/Data/Equality/Graph.hs b/src/Data/Equality/Graph.hs index 1e214d0..94ea5e6 100644 --- a/src/Data/Equality/Graph.hs +++ b/src/Data/Equality/Graph.hs @@ -25,6 +25,7 @@ module Data.Equality.Graph -- ** Querying , find, canonicalize + , unsafeFind, unsafeCanonicalize -- * Re-exports , module Data.Equality.Graph.Classes @@ -64,10 +65,10 @@ import Data.Equality.Graph.Lens -- class it's already represented in will be returned. add :: forall l. Language l => ENode l -> EGraph l -> (ClassId, EGraph l) add uncanon_e egr = - let !new_en = canonicalize uncanon_e egr + let !new_en = unsafeCanonicalize uncanon_e egr in case lookupNM new_en (memo egr) of - Just canon_enode_id -> (find canon_enode_id egr, egr) + Just canon_enode_id -> (canon_enode_id, egr) Nothing -> let @@ -140,8 +141,8 @@ merge a b egr0 = -- Use canonical ids let - a' = find a egr0 - b' = find b egr0 + a' = unsafeFind a egr0 + b' = unsafeFind b egr0 in if a' == b' then (a', egr0) @@ -219,16 +220,17 @@ rebuild (EGraph uf cls mm wl awl) = let emptiedEgr = (EGraph uf cls mm mempty mempty) - wl' = nubOrd $ bimap (`find` emptiedEgr) (`canonicalize` emptiedEgr) <$> wl + wl' = nubOrd $ bimap (`unsafeFind` emptiedEgr) (`unsafeCanonicalize` emptiedEgr) <$> wl egr' = foldr repair emptiedEgr wl' awl' = nubIntOn fst $ first (`find` egr') <$> awl egr'' = foldr repairAnal egr' awl' - in - -- Loop until worklist is completely empty - if null (worklist egr'') && null (analysisWorklist egr'') + in -- trace ("Normal deduped: " <> show (length wl - length wl') <> ". Anal deduped: " <> show (length awl - length awl')) $! + + -- Loop until worklist is completely empty + if null (worklist egr'') && null (analysisWorklist egr'') then egr'' - else rebuild egr'' -- ROMES:TODO: Doesn't seem to be needed at all in the testsuite. + else rebuild egr'' {-# INLINEABLE rebuild #-} -- ROMES:TODO: find repair_id could be shared between repair and repairAnal? @@ -272,17 +274,57 @@ repairAnal (repair_id, node) egr = -- that their e-class ids are represented by the same e-class canonical ids -- -- canonicalize(𝑓(𝑎,𝑏,𝑐,...)) = 𝑓((find 𝑎), (find 𝑏), (find 𝑐),...) -canonicalize :: Functor l => ENode l -> EGraph l -> ENode l -canonicalize (Node enode) eg = Node $ fmap (`find` eg) enode +-- +-- This operation will force the e-graph to be rebuilt since canonicalizing an +-- e-node only makes sense on an e-graph in a sound state. See also +-- 'unsafeCanonicalize'. +canonicalize :: Language l => ENode l -> EGraph l -> ENode l +canonicalize n = unsafeCanonicalize n . rebuild {-# INLINE canonicalize #-} -- | Find the canonical representation of an e-class id in the e-graph +-- +-- This operation will force the e-graph to be rebuilt since finding the +-- canonical representation of an e-class id only makes sense on an e-graph in +-- a sound state. See also 'unsafeFind'. +-- -- Invariant: The e-class id always exists. -find :: ClassId -> EGraph l -> ClassId -find cid = findRepr cid . unionFind +find :: Language l => ClassId -> EGraph l -> ClassId +find cid = unsafeFind cid . rebuild {-# INLINE find #-} -- | The empty e-graph. Nothing is represented in it yet. emptyEGraph :: Language l => EGraph l emptyEGraph = EGraph emptyUF mempty mempty mempty mempty {-# INLINE emptyEGraph #-} + +-- | Like 'canonicalize' but doesn't force a rebuild. +-- +-- It's the responsibility of the caller to ensure that the e-graph is in a +-- sound state, or otherwise know that the result might not represent the +-- /true/ representative (that one could expect given e.g. congruence), because +-- until rebuilt, the e-graph invariants aren't maintained. +-- +-- By using the unsafe variant one might avoid unncessary calls to 'rebuild', +-- but note that if the e-graph is already built calls to 'rebuild' are +-- almost instantaneous. +unsafeCanonicalize :: Functor l => ENode l -> EGraph l -> ENode l +unsafeCanonicalize (Node enode) eg = Node $ fmap (`unsafeFind` eg) enode +{-# INLINE unsafeCanonicalize #-} + +-- | Like 'find' but doesn't force a rebuild. +-- +-- It's the responsibility of the caller to ensure that the e-graph is in a +-- sound state, or otherwise know that the result might not represent the +-- /true/ representative (that one could expect given e.g. congruence), because +-- until rebuilt, the e-graph invariants aren't maintained. +-- +-- By using the unsafe variant one might avoid unncessary calls to 'rebuild', +-- but note that if the e-graph is already built calls to 'rebuild' are +-- almost instantaneous. +-- +-- Invariant: The e-class id always exists. +unsafeFind :: ClassId -> EGraph l -> ClassId +unsafeFind cid = findRepr cid . unionFind +{-# INLINE unsafeFind #-} +