@@ -79,6 +79,7 @@ module Data.Vector.Generic.Mutable (
7979) where
8080
8181import Control.Monad ((<=<) )
82+ import Control.Monad.ST
8283import Data.Vector.Generic.Mutable.Base
8384import qualified Data.Vector.Generic.Base as V
8485
@@ -91,7 +92,7 @@ import Data.Vector.Fusion.Bundle.Size
9192import Data.Vector.Fusion.Util ( delay_inline )
9293import Data.Vector.Internal.Check
9394
94- import Control.Monad.Primitive ( PrimMonad (.. ), RealWorld , stToPrim )
95+ import Control.Monad.Primitive ( PrimMonad (.. ), stToPrim )
9596
9697import Prelude
9798 ( Ord , Monad , Bool (.. ), Int , Maybe (.. ), Either (.. ), Ordering (.. )
@@ -106,8 +107,7 @@ import Data.Bits ( Bits(shiftR) )
106107-- Internal functions
107108-- ------------------
108109
109- unsafeAppend1 :: (PrimMonad m , MVector v a )
110- => v (PrimState m ) a -> Int -> a -> m (v (PrimState m ) a )
110+ unsafeAppend1 :: (MVector v a ) => v s a -> Int -> a -> ST s (v s a )
111111{-# INLINE_INNER unsafeAppend1 #-}
112112 -- NOTE: The case distinction has to be on the outside because
113113 -- GHC creates a join point for the unsafeWrite even when everything
@@ -122,8 +122,7 @@ unsafeAppend1 v i x
122122 checkIndex Internal i (length v') $ unsafeWrite v' i x
123123 return v'
124124
125- unsafePrepend1 :: (PrimMonad m , MVector v a )
126- => v (PrimState m ) a -> Int -> a -> m (v (PrimState m ) a , Int )
125+ unsafePrepend1 :: (MVector v a ) => v s a -> Int -> a -> ST s (v s a , Int )
127126{-# INLINE_INNER unsafePrepend1 #-}
128127unsafePrepend1 v i x
129128 | i /= 0 = do
@@ -207,7 +206,7 @@ unstream :: (PrimMonad m, MVector v a)
207206 => Bundle u a -> m (v (PrimState m ) a )
208207-- NOTE: replace INLINE_FUSED by INLINE? (also in unstreamR)
209208{-# INLINE_FUSED unstream #-}
210- unstream s = munstream (Bundle. lift s)
209+ unstream s = stToPrim $ munstream (Bundle. lift s)
211210
212211-- | Create a new mutable vector and fill it with elements from the monadic
213212-- stream. The vector will grow exponentially if the maximum size of the stream
@@ -243,9 +242,8 @@ munstreamUnknown s
243242 $ unsafeSlice 0 n v'
244243 where
245244 {-# INLINE_INNER put #-}
246- put (v,i) x = do
247- v' <- unsafeAppend1 v i x
248- return (v',i+ 1 )
245+ put (v,i) x = stToPrim $ do v' <- unsafeAppend1 v i x
246+ return (v',i+ 1 )
249247
250248
251249-- | Create a new mutable vector and fill it with elements from the 'Bundle'.
@@ -255,7 +253,7 @@ vunstream :: (PrimMonad m, V.Vector v a)
255253 => Bundle v a -> m (V. Mutable v (PrimState m ) a )
256254-- NOTE: replace INLINE_FUSED by INLINE? (also in unstreamR)
257255{-# INLINE_FUSED vunstream #-}
258- vunstream s = vmunstream (Bundle. lift s)
256+ vunstream s = stToPrim $ vmunstream (Bundle. lift s)
259257
260258-- | Create a new mutable vector and fill it with elements from the monadic
261259-- stream. The vector will grow exponentially if the maximum size of the stream
@@ -311,7 +309,7 @@ unstreamR :: (PrimMonad m, MVector v a)
311309 => Bundle u a -> m (v (PrimState m ) a )
312310-- NOTE: replace INLINE_FUSED by INLINE? (also in unstream)
313311{-# INLINE_FUSED unstreamR #-}
314- unstreamR s = munstreamR (Bundle. lift s)
312+ unstreamR s = stToPrim $ munstreamR (Bundle. lift s)
315313
316314-- | Create a new mutable vector and fill it with elements from the monadic
317315-- stream from right to left. The vector will grow exponentially if the maximum
@@ -350,7 +348,7 @@ munstreamRUnknown s
350348 $ unsafeSlice i (n- i) v'
351349 where
352350 {-# INLINE_INNER put #-}
353- put (v,i) x = unsafePrepend1 v i x
351+ put (v,i) x = stToPrim $ unsafePrepend1 v i x
354352
355353-- Length
356354-- ------
@@ -563,10 +561,9 @@ enlarge_delta :: MVector v a => v s a -> Int
563561enlarge_delta v = max (length v) 1
564562
565563-- | Grow a vector logarithmically.
566- enlarge :: (PrimMonad m , MVector v a )
567- => v (PrimState m ) a -> m (v (PrimState m ) a )
564+ enlarge :: (MVector v a ) => v s a -> ST s (v s a )
568565{-# INLINE enlarge #-}
569- enlarge v = stToPrim $ do
566+ enlarge v = do
570567 vnew <- unsafeGrow v by
571568 basicInitialize $ basicUnsafeSlice (length v) by vnew
572569 return vnew
@@ -996,10 +993,10 @@ unsafeMove dst src = check Unsafe "length mismatch" (length dst == length src)
996993accum :: forall m v a b u . (HasCallStack , PrimMonad m , MVector v a )
997994 => (a -> b -> a ) -> v (PrimState m ) a -> Bundle u (Int , b ) -> m ()
998995{-# INLINE accum #-}
999- accum f ! v s = Bundle. mapM_ upd s
996+ accum f ! v s = stToPrim $ Bundle. mapM_ upd s
1000997 where
1001998 {-# INLINE_INNER upd #-}
1002- upd :: HasCallStack => (Int , b ) -> m ()
999+ upd :: HasCallStack => (Int , b ) -> ST ( PrimState m ) ()
10031000 upd (i,b) = do
10041001 a <- checkIndex Bounds i n $ unsafeRead v i
10051002 unsafeWrite v i (f a b)
@@ -1008,18 +1005,18 @@ accum f !v s = Bundle.mapM_ upd s
10081005update :: forall m v a u . (HasCallStack , PrimMonad m , MVector v a )
10091006 => v (PrimState m ) a -> Bundle u (Int , a ) -> m ()
10101007{-# INLINE update #-}
1011- update ! v s = Bundle. mapM_ upd s
1008+ update ! v s = stToPrim $ Bundle. mapM_ upd s
10121009 where
10131010 {-# INLINE_INNER upd #-}
1014- upd :: HasCallStack => (Int , a ) -> m ()
1011+ upd :: HasCallStack => (Int , a ) -> ST ( PrimState m ) ()
10151012 upd (i,b) = checkIndex Bounds i n $ unsafeWrite v i b
10161013
10171014 ! n = length v
10181015
10191016unsafeAccum :: (PrimMonad m , MVector v a )
10201017 => (a -> b -> a ) -> v (PrimState m ) a -> Bundle u (Int , b ) -> m ()
10211018{-# INLINE unsafeAccum #-}
1022- unsafeAccum f ! v s = Bundle. mapM_ upd s
1019+ unsafeAccum f ! v s = stToPrim $ Bundle. mapM_ upd s
10231020 where
10241021 {-# INLINE_INNER upd #-}
10251022 upd (i,b) = do
@@ -1028,17 +1025,17 @@ unsafeAccum f !v s = Bundle.mapM_ upd s
10281025 ! n = length v
10291026
10301027unsafeUpdate :: (PrimMonad m , MVector v a )
1031- => v (PrimState m ) a -> Bundle u (Int , a ) -> m ()
1028+ => v (PrimState m ) a -> Bundle u (Int , a ) -> m ()
10321029{-# INLINE unsafeUpdate #-}
1033- unsafeUpdate ! v s = Bundle. mapM_ upd s
1030+ unsafeUpdate ! v s = stToPrim $ Bundle. mapM_ upd s
10341031 where
10351032 {-# INLINE_INNER upd #-}
10361033 upd (i,b) = checkIndex Unsafe i n $ unsafeWrite v i b
10371034 ! n = length v
10381035
10391036reverse :: (PrimMonad m , MVector v a ) => v (PrimState m ) a -> m ()
10401037{-# INLINE reverse #-}
1041- reverse ! v = reverse_loop 0 (length v - 1 )
1038+ reverse ! v = stToPrim $ reverse_loop 0 (length v - 1 )
10421039 where
10431040 reverse_loop i j | i < j = do
10441041 unsafeSwap v i j
@@ -1048,11 +1045,11 @@ reverse !v = reverse_loop 0 (length v - 1)
10481045unstablePartition :: forall m v a . (PrimMonad m , MVector v a )
10491046 => (a -> Bool ) -> v (PrimState m ) a -> m Int
10501047{-# INLINE unstablePartition #-}
1051- unstablePartition f ! v = from_left 0 (length v)
1048+ unstablePartition f ! v = stToPrim $ from_left 0 (length v)
10521049 where
10531050 -- NOTE: GHC 6.10.4 panics without the signatures on from_left and
10541051 -- from_right
1055- from_left :: Int -> Int -> m Int
1052+ from_left :: Int -> Int -> ST ( PrimState m ) Int
10561053 from_left i j
10571054 | i == j = return i
10581055 | otherwise = do
@@ -1061,7 +1058,7 @@ unstablePartition f !v = from_left 0 (length v)
10611058 then from_left (i+ 1 ) j
10621059 else from_right i (j- 1 )
10631060
1064- from_right :: Int -> Int -> m Int
1061+ from_right :: Int -> Int -> ST ( PrimState m ) Int
10651062 from_right i j
10661063 | i == j = return i
10671064 | otherwise = do
@@ -1078,7 +1075,8 @@ unstablePartitionBundle :: (PrimMonad m, MVector v a)
10781075 => (a -> Bool ) -> Bundle u a -> m (v (PrimState m ) a , v (PrimState m ) a )
10791076{-# INLINE unstablePartitionBundle #-}
10801077unstablePartitionBundle f s
1081- = case upperBound (Bundle. size s) of
1078+ = stToPrim
1079+ $ case upperBound (Bundle. size s) of
10821080 Just n -> unstablePartitionMax f s n
10831081 Nothing -> partitionUnknown f s
10841082
@@ -1087,7 +1085,7 @@ unstablePartitionMax :: (PrimMonad m, MVector v a)
10871085 -> m (v (PrimState m ) a , v (PrimState m ) a )
10881086{-# INLINE unstablePartitionMax #-}
10891087unstablePartitionMax f s n
1090- = do
1088+ = stToPrim $ do
10911089 v <- checkLength Internal n $ unsafeNew n
10921090 let {-# INLINE_INNER put #-}
10931091 put (i, j) x
@@ -1105,15 +1103,15 @@ partitionBundle :: (PrimMonad m, MVector v a)
11051103 => (a -> Bool ) -> Bundle u a -> m (v (PrimState m ) a , v (PrimState m ) a )
11061104{-# INLINE partitionBundle #-}
11071105partitionBundle f s
1108- = case upperBound (Bundle. size s) of
1106+ = stToPrim
1107+ $ case upperBound (Bundle. size s) of
11091108 Just n -> partitionMax f s n
11101109 Nothing -> partitionUnknown f s
11111110
11121111partitionMax :: (PrimMonad m , MVector v a )
11131112 => (a -> Bool ) -> Bundle u a -> Int -> m (v (PrimState m ) a , v (PrimState m ) a )
11141113{-# INLINE partitionMax #-}
1115- partitionMax f s n
1116- = do
1114+ partitionMax f s n = stToPrim $ do
11171115 v <- checkLength Internal n $ unsafeNew n
11181116
11191117 let {-# INLINE_INNER put #-}
@@ -1138,8 +1136,7 @@ partitionMax f s n
11381136partitionUnknown :: (PrimMonad m , MVector v a )
11391137 => (a -> Bool ) -> Bundle u a -> m (v (PrimState m ) a , v (PrimState m ) a )
11401138{-# INLINE partitionUnknown #-}
1141- partitionUnknown f s
1142- = do
1139+ partitionUnknown f s = stToPrim $ do
11431140 v1 <- unsafeNew 0
11441141 v2 <- unsafeNew 0
11451142 (v1', n1, v2', n2) <- Bundle. foldM' put (v1, 0 , v2, 0 ) s
@@ -1165,15 +1162,16 @@ partitionWithBundle :: (PrimMonad m, MVector v a, MVector v b, MVector v c)
11651162 => (a -> Either b c ) -> Bundle u a -> m (v (PrimState m ) b , v (PrimState m ) c )
11661163{-# INLINE partitionWithBundle #-}
11671164partitionWithBundle f s
1168- = case upperBound (Bundle. size s) of
1165+ = stToPrim
1166+ $ case upperBound (Bundle. size s) of
11691167 Just n -> partitionWithMax f s n
11701168 Nothing -> partitionWithUnknown f s
11711169
11721170partitionWithMax :: (PrimMonad m , MVector v a , MVector v b , MVector v c )
11731171 => (a -> Either b c ) -> Bundle u a -> Int -> m (v (PrimState m ) b , v (PrimState m ) c )
11741172{-# INLINE partitionWithMax #-}
11751173partitionWithMax f s n
1176- = do
1174+ = stToPrim $ do
11771175 v1 <- unsafeNew n
11781176 v2 <- unsafeNew n
11791177 let {-# INLINE_INNER put #-}
@@ -1194,7 +1192,7 @@ partitionWithUnknown :: forall m v u a b c.
11941192 => (a -> Either b c ) -> Bundle u a -> m (v (PrimState m ) b , v (PrimState m ) c )
11951193{-# INLINE partitionWithUnknown #-}
11961194partitionWithUnknown f s
1197- = do
1195+ = stToPrim $ do
11981196 v1 <- unsafeNew 0
11991197 v2 <- unsafeNew 0
12001198 (v1', n1, v2', n2) <- Bundle. foldM' put (v1, 0 , v2, 0 ) s
@@ -1204,14 +1202,14 @@ partitionWithUnknown f s
12041202 where
12051203 put :: (v (PrimState m ) b , Int , v (PrimState m ) c , Int )
12061204 -> a
1207- -> m (v (PrimState m ) b , Int , v (PrimState m ) c , Int )
1205+ -> ST ( PrimState m ) (v (PrimState m ) b , Int , v (PrimState m ) c , Int )
12081206 {-# INLINE_INNER put #-}
12091207 put (v1, i1, v2, i2) x = case f x of
12101208 Left b -> do
1211- v1' <- unsafeAppend1 v1 i1 b
1209+ v1' <- stToPrim $ unsafeAppend1 v1 i1 b
12121210 return (v1', i1+ 1 , v2, i2)
12131211 Right c -> do
1214- v2' <- unsafeAppend1 v2 i2 c
1212+ v2' <- stToPrim $ unsafeAppend1 v2 i2 c
12151213 return (v1, i1, v2', i2+ 1 )
12161214
12171215-- Modifying vectors
0 commit comments