Skip to content

Commit 4541afc

Browse files
committed
Refactoring RWST to an ADT
Improves performance to put it on par with StateT.
1 parent e31ddaf commit 4541afc

File tree

3 files changed

+44
-44
lines changed

3 files changed

+44
-44
lines changed

docs/Control/Monad/RWS/Trans.md

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,7 @@ This module defines the reader-writer-state monad transformer, `RWST`.
55
#### `See`
66

77
``` purescript
8-
type See s a w = { state :: s, result :: a, log :: w }
9-
```
10-
11-
#### `mkSee`
12-
13-
``` purescript
14-
mkSee :: forall s a w. (Monoid w) => s -> a -> w -> See s a w
8+
data See state result writer
159
```
1610

1711
#### `RWST`
@@ -26,7 +20,7 @@ of `ReaderT`, `WriterT` and `StateT` into a single monad transformer.
2620

2721
##### Instances
2822
``` purescript
29-
instance functorRWST :: (Functor m) => Functor (RWST r w s m)
23+
instance functorRWST :: (Functor m, Monoid w) => Functor (RWST r w s m)
3024
instance applyRWST :: (Bind m, Monoid w) => Apply (RWST r w s m)
3125
instance bindRWST :: (Bind m, Monoid w) => Bind (RWST r w s m)
3226
instance applicativeRWST :: (Monad m, Monoid w) => Applicative (RWST r w s m)

src/Control/Monad/RWS/Trans.purs

Lines changed: 25 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
-- | This module defines the reader-writer-state monad transformer, `RWST`.
22

33
module Control.Monad.RWS.Trans
4-
( See(), mkSee
4+
( See()
55
, RWST(..), runRWST, evalRWST, execRWST, mapRWST, withRWST
66
, module Control.Monad.Trans
77
, module Control.Monad.RWS.Class
@@ -22,14 +22,7 @@ import Control.Monad.State.Class
2222
import Control.Monad.Trans
2323
import Control.Monad.Writer.Class
2424

25-
type See s a w =
26-
{ state :: s
27-
, result :: a
28-
, log :: w
29-
}
30-
31-
mkSee :: forall s a w. (Monoid w) => s -> a -> w -> See s a w
32-
mkSee s a w = { state: s, result: a, log: w }
25+
data See state result writer = See state result writer
3326

3427
-- | The reader-writer-state monad transformer, which combines the operations
3528
-- | of `ReaderT`, `WriterT` and `StateT` into a single monad transformer.
@@ -41,11 +34,11 @@ runRWST (RWST x) = x
4134

4235
-- | Run a computation in the `RWST` monad, discarding the final state.
4336
evalRWST :: forall r w s m a. (Monad m) => RWST r w s m a -> r -> s -> m (Tuple a w)
44-
evalRWST m r s = runRWST m r s >>= \see -> return (Tuple see.result see.log)
37+
evalRWST m r s = runRWST m r s >>= \(See _ result writer) -> return (Tuple result writer)
4538

4639
-- | Run a computation in the `RWST` monad, discarding the result.
4740
execRWST :: forall r w s m a. (Monad m) => RWST r w s m a -> r -> s -> m (Tuple s w)
48-
execRWST m r s = runRWST m r s >>= \see -> return (Tuple see.state see.log)
41+
execRWST m r s = runRWST m r s >>= \(See state _ writer) -> return (Tuple state writer)
4942

5043
-- | Change the result and accumulator types in a `RWST` monad action.
5144
mapRWST :: forall r w1 w2 s m1 m2 a1 a2. (m1 (See s a1 w1) -> m2 (See s a2 w2)) -> RWST r w1 s m1 a1 -> RWST r w2 s m2 a2
@@ -55,43 +48,43 @@ mapRWST f m = RWST \r s -> f $ runRWST m r s
5548
withRWST :: forall r1 r2 w s m a. (r2 -> s -> Tuple r1 s) -> RWST r1 w s m a -> RWST r2 w s m a
5649
withRWST f m = RWST \r s -> uncurry (runRWST m) (f r s)
5750

58-
instance functorRWST :: (Functor m) => Functor (RWST r w s m) where
59-
map f m = RWST \r s -> (\see -> see{result = f see.result}) <$> runRWST m r s
51+
instance functorRWST :: (Functor m, Monoid w) => Functor (RWST r w s m) where
52+
map f m = RWST \r s -> (\(See state result writer) -> See state (f result) writer) <$> runRWST m r s
6053

6154
instance applyRWST :: (Bind m, Monoid w) => Apply (RWST r w s m) where
6255
apply f m = RWST \r s ->
63-
runRWST f r s >>= \{state = s', result = f', log = w'} ->
64-
runRWST m r s' <#> \{state = s'', result = a'', log = w''} ->
65-
mkSee s'' (f' a'') (w' ++ w'')
56+
runRWST f r s >>= \(See s' f' w') ->
57+
runRWST m r s' <#> \(See s'' a'' w'') ->
58+
See s'' (f' a'') (w' ++ w'')
6659

6760
instance bindRWST :: (Bind m, Monoid w) => Bind (RWST r w s m) where
6861
bind m f = RWST \r s ->
69-
runRWST m r s >>= \{result = a, state = s', log = l} ->
70-
runRWST (f a) r s' <#> \see' ->
71-
see' { log = l ++ see'.log }
62+
runRWST m r s >>= \(See s' a w) ->
63+
runRWST (f a) r s' <#> \(See state result writer) ->
64+
See state result (w ++ writer)
7265

7366
instance applicativeRWST :: (Monad m, Monoid w) => Applicative (RWST r w s m) where
74-
pure a = RWST \_ s -> pure $ mkSee s a mempty
67+
pure a = RWST \_ s -> pure $ See s a mempty
7568

7669
instance monadRWST :: (Monad m, Monoid w) => Monad (RWST r w s m)
7770

7871
instance monadTransRWST :: (Monoid w) => MonadTrans (RWST r w s) where
79-
lift m = RWST \_ s -> m >>= \a -> return $ mkSee s a mempty
72+
lift m = RWST \_ s -> m >>= \a -> return $ See s a mempty
8073

8174
instance monadEffRWS :: (Monad m, Monoid w, MonadEff eff m) => MonadEff eff (RWST r w s m) where
8275
liftEff = lift <<< liftEff
8376

8477
instance monadReaderRWST :: (Monad m, Monoid w) => MonadReader r (RWST r w s m) where
85-
ask = RWST \r s -> pure $ mkSee s r mempty
78+
ask = RWST \r s -> pure $ See s r mempty
8679
local f m = RWST \r s -> runRWST m (f r) s
8780

8881
instance monadStateRWST :: (Monad m, Monoid w) => MonadState s (RWST r w s m) where
89-
state f = RWST \_ s -> case f s of Tuple a s' -> pure $ mkSee s' a mempty
82+
state f = RWST \_ s -> case f s of Tuple a s' -> pure $ See s' a mempty
9083

9184
instance monadWriterRWST :: (Monad m, Monoid w) => MonadWriter w (RWST r w s m) where
92-
writer (Tuple a w) = RWST \_ s -> pure $ { state: s, result: a, log: w }
93-
listen m = RWST \r s -> runRWST m r s >>= \{ state: s', result: a, log: w} -> pure { state: s', result: Tuple a w, log: w }
94-
pass m = RWST \r s -> runRWST m r s >>= \{ result: Tuple a f, state: s', log: w} -> pure { state: s', result: a, log: f w }
85+
writer (Tuple a w) = RWST \_ s -> pure $ See s a w
86+
listen m = RWST \r s -> runRWST m r s >>= \(See s' a w) -> pure $ See s' (Tuple a w) w
87+
pass m = RWST \r s -> runRWST m r s >>= \(See s' (Tuple a f) w) -> pure $ See s' a (f w)
9588

9689
instance monadRWSRWST :: (Monad m, Monoid w) => MonadRWS r w s (RWST r w s m)
9790

@@ -100,10 +93,10 @@ instance monadErrorRWST :: (MonadError e m, Monoid w) => MonadError e (RWST r w
10093
catchError m h = RWST $ \r s -> catchError (runRWST m r s) (\e -> runRWST (h e) r s)
10194

10295
instance monadRecRWST :: (Monoid w, MonadRec m) => MonadRec (RWST r w s m) where
103-
tailRecM k a = RWST \r s -> tailRecM (k' r) { writer: mempty, state: s, result: a }
96+
tailRecM k a = RWST \r s -> tailRecM (k' r) (See s a mempty)
10497
where
105-
k' r o = do
106-
see <- runRWST (k o.result) r o.state
107-
return case see.result of
108-
Left a -> Left { state: see.state, result: a, writer: o.writer <> see.log }
109-
Right b -> Right (mkSee see.state b (o.writer <> see.log))
98+
k' r (See state result writer) = do
99+
See state' result' writer' <- runRWST (k result) r state
100+
return case result' of
101+
Left a -> Left (See state' a (writer <> writer'))
102+
Right b -> Right (See state' b (writer <> writer'))

test/Example/RWS.purs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import Control.Monad.RWS
88
import Control.Monad.RWS.Trans
99
import Control.Monad.Rec.Class
1010
import Control.Monad.State
11+
import Control.Monad.State.Trans
1112
import Control.Monad.Writer
1213

1314
import Data.Either
@@ -24,12 +25,24 @@ loop n = tailRecM go n
2425
put (x + 1)
2526
return (Left (n - 1))
2627

28+
loopState :: Int -> StateT Int Identity Unit
29+
loopState n = tailRecM go n
30+
where
31+
go 0 = do
32+
return (Right unit)
33+
go n = do
34+
x <- get
35+
put (x + 1)
36+
return (Left (n - 1))
37+
2738
main = do
2839
t1 <- t
29-
res <- pure $ runIdentity (runRWST (loop 10000) "" 0)
40+
res1 <- pure $ runIdentity (runRWST (loop 1000000) "" 0)
3041
t2 <- t
31-
print $ "RWST.state: " ++ show res.state
32-
print $ "RWST.log: " ++ show res.log
33-
print $ "t2 - t1 = " ++ show (t2 - t1)
42+
print $ "RWST: " ++ show (t2 - t1)
43+
t3 <- t
44+
res2 <- pure $ execState (loopState 1000000) 0
45+
t4 <- t
46+
print $ "StateT: " ++ show (t4 - t3)
3447

3548
foreign import t :: forall eff. Eff eff Number

0 commit comments

Comments
 (0)