From 2c1d08cb5a67e13c6245f3aa4d0e5617fa66511c Mon Sep 17 00:00:00 2001 From: Sergey Vinokurov Date: Sat, 14 Oct 2023 10:55:02 +0100 Subject: [PATCH 1/4] =?UTF-8?q?Introduce=20=E2=80=98partitionKeys=E2=80=99?= =?UTF-8?q?=20that=20fuses=20=E2=80=98restrictKeys=E2=80=99=20and=20?= =?UTF-8?q?=E2=80=98withoutKeys=E2=80=99=20in=20one=20go?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- containers-tests/benchmarks/Map.hs | 12 +++++- containers-tests/tests/map-properties.hs | 7 ++++ containers/src/Data/Map/Internal.hs | 44 ++++++++++++++++++++++ containers/src/Data/Map/Lazy.hs | 1 + containers/src/Data/Map/Strict.hs | 1 + containers/src/Data/Map/Strict/Internal.hs | 4 +- 6 files changed, 67 insertions(+), 2 deletions(-) diff --git a/containers-tests/benchmarks/Map.hs b/containers-tests/benchmarks/Map.hs index 0e324e556..4cbe36577 100644 --- a/containers-tests/benchmarks/Map.hs +++ b/containers-tests/benchmarks/Map.hs @@ -5,7 +5,7 @@ module Main where import Control.Applicative (Const(Const, getConst), pure) import Control.DeepSeq (rnf) import Control.Exception (evaluate) -import Test.Tasty.Bench (bench, defaultMain, whnf, nf) +import Test.Tasty.Bench (bench, defaultMain, whnf, nf, bcompare) import Data.Functor.Identity (Identity(..)) import Data.List (foldl') import qualified Data.Map as M @@ -15,13 +15,16 @@ import Data.Maybe (fromMaybe) import Data.Functor ((<$)) import Data.Coerce import Prelude hiding (lookup) +import Utils.Containers.Internal.StrictPair main = do let m = M.fromAscList elems :: M.Map Int Int m_even = M.fromAscList elems_even :: M.Map Int Int m_odd = M.fromAscList elems_odd :: M.Map Int Int + m_odd_keys = M.keysSet m_odd evaluate $ rnf [m, m_even, m_odd] evaluate $ rnf elems_rev + evaluate $ rnf m_odd_keys defaultMain [ bench "lookup absent" $ whnf (lookup evens) m_odd , bench "lookup present" $ whnf (lookup evens) m_even @@ -95,8 +98,15 @@ main = do , bench "fromDistinctDescList" $ whnf M.fromDistinctDescList elems_rev , bench "fromDistinctDescList:fusion" $ whnf (\n -> M.fromDistinctDescList [(i,i) | i <- [n,n-1..1]]) bound , bench "minView" $ whnf (\m' -> case M.minViewWithKey m' of {Nothing -> 0; Just ((k,v),m'') -> k+v+M.size m''}) (M.fromAscList $ zip [1..10::Int] [100..110::Int]) + , bench "eq" $ whnf (\m' -> m' == m') m -- worst case, compares everything , bench "compare" $ whnf (\m' -> compare m' m') m -- worst case, compares everything + + , bench "restrictKeys+withoutKeys" + $ whnf (\ks -> M.restrictKeys m ks :*: M.withoutKeys m ks) m_odd_keys + , bcompare "/restrictKeys+withoutKeys/" + $ bench "partitionKeys" + $ whnf (M.partitionKeys m) m_odd_keys ] where bound = 2^12 diff --git a/containers-tests/tests/map-properties.hs b/containers-tests/tests/map-properties.hs index 6b7a45ad5..d6eea949e 100644 --- a/containers-tests/tests/map-properties.hs +++ b/containers-tests/tests/map-properties.hs @@ -173,6 +173,7 @@ main = defaultMain $ testGroup "map-properties" , testProperty "withoutKeys" prop_withoutKeys , testProperty "intersection" prop_intersection , testProperty "restrictKeys" prop_restrictKeys + , testProperty "partitionKeys" prop_partitionKeys , testProperty "intersection model" prop_intersectionModel , testProperty "intersectionWith" prop_intersectionWith , testProperty "intersectionWithModel" prop_intersectionWithModel @@ -1140,6 +1141,12 @@ prop_withoutKeys m s0 = valid reduced .&&. (m `withoutKeys` s === filterWithKey s = keysSet s0 reduced = withoutKeys m s +prop_partitionKeys :: IMap -> IMap -> Property +prop_partitionKeys m s0 = valid with .&&. valid without .&&. (m `partitionKeys` s === (m `restrictKeys` s, m `withoutKeys` s)) + where + s = keysSet s0 + (with, without) = partitionKeys m s + prop_intersection :: IMap -> IMap -> Bool prop_intersection t1 t2 = valid (intersection t1 t2) diff --git a/containers/src/Data/Map/Internal.hs b/containers/src/Data/Map/Internal.hs index b230a574e..2baace297 100644 --- a/containers/src/Data/Map/Internal.hs +++ b/containers/src/Data/Map/Internal.hs @@ -7,6 +7,7 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE Trustworthy #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE ScopedTypeVariables #-} #define USE_MAGIC_PROXY 1 #endif @@ -299,6 +300,7 @@ module Data.Map.Internal ( , restrictKeys , withoutKeys + , partitionKeys , partition , partitionWithKey @@ -1966,6 +1968,48 @@ withoutKeys m (Set.Bin _ k ls rs) = case splitMember k m of {-# INLINABLE withoutKeys #-} #endif +-- | \(O\bigl(m \log\bigl(\frac{n}{m}+1\bigr)\bigr), \; 0 < m \leq n\). Partition the map according to a set. +-- The first map contains the input 'Map' restricted to those keys found in the 'Set', +-- the second map contains the input 'Map' without all keys in the 'Set'. +-- This is more efficient than using 'restrictKeys' and 'withoutKeys' together. +-- +-- @ +-- m \`partitionKeys\` s = (m ``restrictKeys`` s, m ``withoutKeys`` s) +-- @ +partitionKeys :: forall k a. Ord k => Map k a -> Set k -> (Map k a, Map k a) +partitionKeys xs ys = + case go xs ys of + xs' :*: ys' -> (xs', ys') + where + go :: Map k a -> Set k -> StrictPair (Map k a) (Map k a) + go Tip _ = Tip :*: Tip + go m Set.Tip = Tip :*: m + go m@(Bin _ k x lm rm) s@Set.Bin{} = + case b of + True -> with :*: without + where + with = + if lmWith `ptrEq` lm && rmWith `ptrEq` rm + then m + else link k x lmWith rmWith + without = + link2 lmWithout rmWithout + False -> with :*: without + where + with = link2 lmWith rmWith + without = + if lmWithout `ptrEq` lm && rmWithout `ptrEq` rm + then m + else link k x lmWithout rmWithout + where + !(lmWith :*: lmWithout) = go lm ls' + !(rmWith :*: rmWithout) = go rm rs' + + !(!ls', b, !rs') = Set.splitMember k s +#if __GLASGOW_HASKELL__ +{-# INLINABLE partitionKeys #-} +#endif + -- | \(O(n+m)\). Difference with a combining function. -- When two equal keys are -- encountered, the combining function is applied to the values of these keys. diff --git a/containers/src/Data/Map/Lazy.hs b/containers/src/Data/Map/Lazy.hs index 2fca4c91d..d69e943f5 100644 --- a/containers/src/Data/Map/Lazy.hs +++ b/containers/src/Data/Map/Lazy.hs @@ -231,6 +231,7 @@ module Data.Map.Lazy ( , filterWithKey , restrictKeys , withoutKeys + , partitionKeys , partition , partitionWithKey , takeWhileAntitone diff --git a/containers/src/Data/Map/Strict.hs b/containers/src/Data/Map/Strict.hs index 649898850..632a1bf70 100644 --- a/containers/src/Data/Map/Strict.hs +++ b/containers/src/Data/Map/Strict.hs @@ -246,6 +246,7 @@ module Data.Map.Strict , filterWithKey , restrictKeys , withoutKeys + , partitionKeys , partition , partitionWithKey diff --git a/containers/src/Data/Map/Strict/Internal.hs b/containers/src/Data/Map/Strict/Internal.hs index 21afe2d91..9e6be1235 100644 --- a/containers/src/Data/Map/Strict/Internal.hs +++ b/containers/src/Data/Map/Strict/Internal.hs @@ -256,6 +256,7 @@ module Data.Map.Strict.Internal , filterWithKey , restrictKeys , withoutKeys + , partitionKeys , partition , partitionWithKey , takeWhileAntitone @@ -418,7 +419,8 @@ import Data.Map.Internal , toDescList , union , unions - , withoutKeys ) + , withoutKeys + , partitionKeys ) import Data.Map.Internal.Debug (valid) From e66de6a05fa13884465a303e76159cab36b67c01 Mon Sep 17 00:00:00 2001 From: Sergey Vinokurov Date: Sun, 15 Oct 2023 11:31:32 +0100 Subject: [PATCH 2/4] =?UTF-8?q?Introduce=20recursive=20worker=20to=20?= =?UTF-8?q?=E2=80=98splitMember=E2=80=99=20to=20increase=20inlining=20chan?= =?UTF-8?q?ces?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- containers-tests/containers-tests.cabal | 1 + containers/containers.cabal | 1 + containers/src/Data/Map/Internal.hs | 3 +-- containers/src/Data/Set/Internal.hs | 25 +++++++++++-------- .../Utils/Containers/Internal/StrictTriple.hs | 17 +++++++++++++ 5 files changed, 35 insertions(+), 12 deletions(-) create mode 100644 containers/src/Utils/Containers/Internal/StrictTriple.hs diff --git a/containers-tests/containers-tests.cabal b/containers-tests/containers-tests.cabal index da7c76a0f..289f31dc2 100644 --- a/containers-tests/containers-tests.cabal +++ b/containers-tests/containers-tests.cabal @@ -128,6 +128,7 @@ library Utils.Containers.Internal.State Utils.Containers.Internal.StrictMaybe Utils.Containers.Internal.EqOrdUtil + Utils.Containers.Internal.StrictTriple if impl(ghc) other-modules: diff --git a/containers/containers.cabal b/containers/containers.cabal index 12185a9f5..eeb72dae6 100644 --- a/containers/containers.cabal +++ b/containers/containers.cabal @@ -83,6 +83,7 @@ Library Utils.Containers.Internal.PtrEquality Utils.Containers.Internal.Coercions Utils.Containers.Internal.EqOrdUtil + Utils.Containers.Internal.StrictTriple if impl(ghc) other-modules: Utils.Containers.Internal.TypeError diff --git a/containers/src/Data/Map/Internal.hs b/containers/src/Data/Map/Internal.hs index 2baace297..81a912502 100644 --- a/containers/src/Data/Map/Internal.hs +++ b/containers/src/Data/Map/Internal.hs @@ -400,6 +400,7 @@ import qualified Data.Set.Internal as Set import Data.Set.Internal (Set) import Utils.Containers.Internal.PtrEquality (ptrEq) import Utils.Containers.Internal.StrictPair +import Utils.Containers.Internal.StrictTriple import Utils.Containers.Internal.StrictMaybe import Utils.Containers.Internal.BitQueue import Utils.Containers.Internal.EqOrdUtil (EqM(..), OrdM(..)) @@ -4048,8 +4049,6 @@ splitMember k0 m = case go k0 m of {-# INLINABLE splitMember #-} #endif -data StrictTriple a b c = StrictTriple !a !b !c - {-------------------------------------------------------------------- Utility functions that maintain the balance properties of the tree. All constructors assume that all values in [l] < [k] and all values diff --git a/containers/src/Data/Set/Internal.hs b/containers/src/Data/Set/Internal.hs index 6bc472f3d..60b0b4799 100644 --- a/containers/src/Data/Set/Internal.hs +++ b/containers/src/Data/Set/Internal.hs @@ -256,6 +256,7 @@ import Data.List.NonEmpty (NonEmpty(..)) #endif import Utils.Containers.Internal.StrictPair +import Utils.Containers.Internal.StrictTriple import Utils.Containers.Internal.PtrEquality import Utils.Containers.Internal.EqOrdUtil (EqM(..), OrdM(..)) @@ -1433,16 +1434,20 @@ splitS x (Bin _ y l r) -- | \(O(\log n)\). Performs a 'split' but also returns whether the pivot -- element was found in the original set. splitMember :: Ord a => a -> Set a -> (Set a,Bool,Set a) -splitMember _ Tip = (Tip, False, Tip) -splitMember x (Bin _ y l r) - = case compare x y of - LT -> let (lt, found, gt) = splitMember x l - !gt' = link y gt r - in (lt, found, gt') - GT -> let (lt, found, gt) = splitMember x r - !lt' = link y l lt - in (lt', found, gt) - EQ -> (l, True, r) +splitMember k0 s = case go k0 s of + StrictTriple l b r -> (l, b, r) + where + go :: Ord a => a -> Set a -> StrictTriple (Set a) Bool (Set a) + go _ Tip = StrictTriple Tip False Tip + go x (Bin _ y l r) + = case compare x y of + LT -> let StrictTriple lt found gt = go x l + !gt' = link y gt r + in StrictTriple lt found gt' + GT -> let StrictTriple lt found gt = go x r + !lt' = link y l lt + in StrictTriple lt' found gt + EQ -> StrictTriple l True r #if __GLASGOW_HASKELL__ {-# INLINABLE splitMember #-} #endif diff --git a/containers/src/Utils/Containers/Internal/StrictTriple.hs b/containers/src/Utils/Containers/Internal/StrictTriple.hs new file mode 100644 index 000000000..d55e3d1a4 --- /dev/null +++ b/containers/src/Utils/Containers/Internal/StrictTriple.hs @@ -0,0 +1,17 @@ +{-# LANGUAGE CPP #-} +#if !defined(TESTING) && defined(__GLASGOW_HASKELL__) +{-# LANGUAGE Safe #-} +#endif + +#include "containers.h" + +-- | A strict triple + +module Utils.Containers.Internal.StrictTriple (StrictTriple(..)) where + +-- | The same as a regular Haskell tuple, but +-- +-- @ +-- StrictTriple x y _|_ = StrictTriple x _|_ z = StrictTriple _|_ y z = _|_ +-- @ +data StrictTriple a b c = StrictTriple !a !b !c From b6d6533eb72aac9e1208d277fba3008f7a12a6e7 Mon Sep 17 00:00:00 2001 From: Sergey Vinokurov Date: Wed, 3 Apr 2024 01:10:55 +0100 Subject: [PATCH 3/4] Review suggestions: prefer toplevel functions to internal workers --- containers/src/Data/Map/Internal.hs | 57 ++++++++++--------- containers/src/Data/Set/Internal.hs | 28 ++++----- .../Utils/Containers/Internal/StrictTriple.hs | 2 - 3 files changed, 45 insertions(+), 42 deletions(-) diff --git a/containers/src/Data/Map/Internal.hs b/containers/src/Data/Map/Internal.hs index 81a912502..46be61227 100644 --- a/containers/src/Data/Map/Internal.hs +++ b/containers/src/Data/Map/Internal.hs @@ -1979,38 +1979,41 @@ withoutKeys m (Set.Bin _ k ls rs) = case splitMember k m of -- @ partitionKeys :: forall k a. Ord k => Map k a -> Set k -> (Map k a, Map k a) partitionKeys xs ys = - case go xs ys of + case partitionKeysWorker xs ys of xs' :*: ys' -> (xs', ys') - where - go :: Map k a -> Set k -> StrictPair (Map k a) (Map k a) - go Tip _ = Tip :*: Tip - go m Set.Tip = Tip :*: m - go m@(Bin _ k x lm rm) s@Set.Bin{} = - case b of - True -> with :*: without - where - with = - if lmWith `ptrEq` lm && rmWith `ptrEq` rm - then m - else link k x lmWith rmWith - without = - link2 lmWithout rmWithout - False -> with :*: without - where - with = link2 lmWith rmWith - without = - if lmWithout `ptrEq` lm && rmWithout `ptrEq` rm - then m - else link k x lmWithout rmWithout - where - !(lmWith :*: lmWithout) = go lm ls' - !(rmWith :*: rmWithout) = go rm rs' - - !(!ls', b, !rs') = Set.splitMember k s #if __GLASGOW_HASKELL__ {-# INLINABLE partitionKeys #-} #endif +partitionKeysWorker :: Ord k => Map k a -> Set k -> StrictPair (Map k a) (Map k a) +partitionKeysWorker Tip _ = Tip :*: Tip +partitionKeysWorker m Set.Tip = Tip :*: m +partitionKeysWorker m@(Bin _ k x lm rm) s@Set.Bin{} = + case b of + True -> with :*: without + where + with = + if lmWith `ptrEq` lm && rmWith `ptrEq` rm + then m + else link k x lmWith rmWith + without = + link2 lmWithout rmWithout + False -> with :*: without + where + with = link2 lmWith rmWith + without = + if lmWithout `ptrEq` lm && rmWithout `ptrEq` rm + then m + else link k x lmWithout rmWithout + where + !(lmWith :*: lmWithout) = partitionKeysWorker lm ls' + !(rmWith :*: rmWithout) = partitionKeysWorker rm rs' + + !(!ls', b, !rs') = Set.splitMember k s +#if __GLASGOW_HASKELL__ +{-# INLINABLE partitionKeysWorker #-} +#endif + -- | \(O(n+m)\). Difference with a combining function. -- When two equal keys are -- encountered, the combining function is applied to the values of these keys. diff --git a/containers/src/Data/Set/Internal.hs b/containers/src/Data/Set/Internal.hs index 60b0b4799..432b674f6 100644 --- a/containers/src/Data/Set/Internal.hs +++ b/containers/src/Data/Set/Internal.hs @@ -1431,23 +1431,25 @@ splitS x (Bin _ y l r) EQ -> (l :*: r) {-# INLINABLE splitS #-} +splitMemberS :: Ord a => a -> Set a -> StrictTriple (Set a) Bool (Set a) +splitMemberS x = go + where + go Tip = StrictTriple Tip False Tip + go (Bin _ y l r) = case compare x y of + LT -> let StrictTriple lt found gt = splitMemberS x l + in StrictTriple lt found (link y gt r) + GT -> let StrictTriple lt found gt = splitMemberS x r + in StrictTriple (link y l lt) found gt + EQ -> StrictTriple l True r +#if __GLASGOW_HASKELL__ +{-# INLINABLE splitMemberS #-} +#endif + -- | \(O(\log n)\). Performs a 'split' but also returns whether the pivot -- element was found in the original set. splitMember :: Ord a => a -> Set a -> (Set a,Bool,Set a) -splitMember k0 s = case go k0 s of +splitMember k0 s = case splitMemberS k0 s of StrictTriple l b r -> (l, b, r) - where - go :: Ord a => a -> Set a -> StrictTriple (Set a) Bool (Set a) - go _ Tip = StrictTriple Tip False Tip - go x (Bin _ y l r) - = case compare x y of - LT -> let StrictTriple lt found gt = go x l - !gt' = link y gt r - in StrictTriple lt found gt' - GT -> let StrictTriple lt found gt = go x r - !lt' = link y l lt - in StrictTriple lt' found gt - EQ -> StrictTriple l True r #if __GLASGOW_HASKELL__ {-# INLINABLE splitMember #-} #endif diff --git a/containers/src/Utils/Containers/Internal/StrictTriple.hs b/containers/src/Utils/Containers/Internal/StrictTriple.hs index d55e3d1a4..45523f81f 100644 --- a/containers/src/Utils/Containers/Internal/StrictTriple.hs +++ b/containers/src/Utils/Containers/Internal/StrictTriple.hs @@ -3,8 +3,6 @@ {-# LANGUAGE Safe #-} #endif -#include "containers.h" - -- | A strict triple module Utils.Containers.Internal.StrictTriple (StrictTriple(..)) where From 9f7c5bd9ee38e580f3f8835f18915b2855420cf4 Mon Sep 17 00:00:00 2001 From: Sergey Vinokurov Date: Sat, 6 Apr 2024 11:21:05 +0100 Subject: [PATCH 4/4] More review suggestions --- containers-tests/benchmarks/Map.hs | 11 ++++------- containers/src/Data/Map/Internal.hs | 3 +-- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/containers-tests/benchmarks/Map.hs b/containers-tests/benchmarks/Map.hs index 4cbe36577..392c15bb2 100644 --- a/containers-tests/benchmarks/Map.hs +++ b/containers-tests/benchmarks/Map.hs @@ -5,7 +5,7 @@ module Main where import Control.Applicative (Const(Const, getConst), pure) import Control.DeepSeq (rnf) import Control.Exception (evaluate) -import Test.Tasty.Bench (bench, defaultMain, whnf, nf, bcompare) +import Test.Tasty.Bench (bench, defaultMain, whnf, nf) import Data.Functor.Identity (Identity(..)) import Data.List (foldl') import qualified Data.Map as M @@ -15,7 +15,6 @@ import Data.Maybe (fromMaybe) import Data.Functor ((<$)) import Data.Coerce import Prelude hiding (lookup) -import Utils.Containers.Internal.StrictPair main = do let m = M.fromAscList elems :: M.Map Int Int @@ -102,11 +101,9 @@ main = do , bench "eq" $ whnf (\m' -> m' == m') m -- worst case, compares everything , bench "compare" $ whnf (\m' -> compare m' m') m -- worst case, compares everything - , bench "restrictKeys+withoutKeys" - $ whnf (\ks -> M.restrictKeys m ks :*: M.withoutKeys m ks) m_odd_keys - , bcompare "/restrictKeys+withoutKeys/" - $ bench "partitionKeys" - $ whnf (M.partitionKeys m) m_odd_keys + , bench "restrictKeys" $ whnf (M.restrictKeys m) m_odd_keys + , bench "withoutKeys" $ whnf (M.withoutKeys m) m_odd_keys + , bench "partitionKeys" $ whnf (M.partitionKeys m) m_odd_keys ] where bound = 2^12 diff --git a/containers/src/Data/Map/Internal.hs b/containers/src/Data/Map/Internal.hs index 46be61227..8cd484536 100644 --- a/containers/src/Data/Map/Internal.hs +++ b/containers/src/Data/Map/Internal.hs @@ -7,7 +7,6 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE Trustworthy #-} {-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE ScopedTypeVariables #-} #define USE_MAGIC_PROXY 1 #endif @@ -1977,7 +1976,7 @@ withoutKeys m (Set.Bin _ k ls rs) = case splitMember k m of -- @ -- m \`partitionKeys\` s = (m ``restrictKeys`` s, m ``withoutKeys`` s) -- @ -partitionKeys :: forall k a. Ord k => Map k a -> Set k -> (Map k a, Map k a) +partitionKeys :: Ord k => Map k a -> Set k -> (Map k a, Map k a) partitionKeys xs ys = case partitionKeysWorker xs ys of xs' :*: ys' -> (xs', ys')