Saturday, September 16, 2006

Local and global side effects with monad transformers

There is an annual puzzle event that 90 or so people attend. Each participant designs a puzzle and manufactures 90 copies of it, which are then shared with the other participants. All 90 then get to go home with 90 puzzles.

Anyway, a former coworker attends this event, and I had a chance to play with his puzzle. It was one of those puzzles where you fit blocks together. It was ingeniously designed, but it seemed pretty clear to me that solving it required a lot of combinatorial search. So I decided to write a Haskell program to solve it, using the List monad for simple logic programming.

But my code had a slight problem. It worked through the puzzle and found solutions, but (1) it didn't log the steps it took to reach those solutions, and (2) I couldn't count how many combinations it searched to find them. Both of these can be thought of as side effects, so the obvious thing to do was to use a monad to track them. But there was a catch: I was already using a monad, the List monad. When that happens there's only one thing for it: use a monad transformer to combine the monads.

There are two distinct ways to combine a monad with the List monad, and I needed both.

This post is literate Haskell. I'll assume you're vaguely familiar with using the List monad for logic programming. I also assume you're familiar with MonadPlus, though I haven't written about it here. I couldn't get this code to work in Hugs, so use GHC. Here's some code:


> import Control.Monad
> import Control.Monad.List    -- provides ListT; not available in newer mtl releases
> import Control.Monad.State
> import Control.Monad.Writer


I'm not going to describe the original puzzle now. Instead I'm going to look at an almost trivial logic problem so that we can concentrate on the monad transforming. The puzzle is this: find all the possible sums of pairs of integers, one chosen from the set {1,2} and the other from the set {2,4}, where the sum is less than 5. Here's a simple implementation.


> test1 :: [Integer]

> test1 = do
>   a <- [1,2]
>   b <- [2,4]
>   guard $ a+b<5
>   return $ a+b

> go1 = test1


Run go1 and you should get the result [3,4]. But that's just the sums. What were the pairs of integers that went into those sums? We could simply return (a,b,a+b), but in more complex problems we might want to log a complex sequence of choices and that would entail carrying all of that information around. What we'd like is to simply have some kind of log running as a side effect. For this we need the Writer monad.
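For comparison, here is a minimal sketch of the "just return the pair as well" approach mentioned above. The name test1' is made up for this illustration. Every extra piece of information we want to keep has to be threaded through the return value by hand.

```haskell
import Control.Monad (guard)

-- Return the chosen pair along with its sum instead of logging it.
-- (test1' is a hypothetical name used only for this illustration.)
test1' :: [(Integer, Integer, Integer)]
test1' = do
  a <- [1,2]
  b <- [2,4]
  guard (a + b < 5)
  return (a, b, a + b)

-- test1' == [(1,2,3),(2,2,4)]
```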

If you cast your mind back, monad transformers layer up monads a bit like layers of onion skin. What we want is to wrap a List monad in a Writer monad, and we do this using the WriterT monad transformer. All we have to do is add a tell line to our code and use 'lift' to bring computations in the inner List monad up through one layer of the onion. Here's what the code looks like:


> test2 :: WriterT [Char] [] Integer

> test2 = do
>   a <- lift [1,2]
>   b <- lift [2,4]
>   tell ("trying "++show a++" "++show b++"\n")
>   guard $ a+b<5
>   return $ a+b


To get the final result we need to use runWriterT to peel the onion:


> go2 = runWriterT test2


Run go2 and you get a list of pairs, each containing a sum and the messages logged on the way to it.
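To show what lift does at this type, here is a sketch that builds the lifted lists by hand. The helper liftW is a hypothetical name, not a library function. It assumes the WriterT constructor exported by mtl's Control.Monad.Writer. Lifting pairs each list element with an empty log, so only tell adds messages. runWriterT then returns one log per surviving branch, and the log from a failed branch is discarded along with it.

```haskell
import Control.Monad (guard)
import Control.Monad.Writer (WriterT (WriterT), runWriterT, tell)

-- Hypothetical stand-in for 'lift' at this type: every list element is
-- paired with an empty log, so lifting contributes no messages itself.
liftW :: [a] -> WriterT String [] a
liftW xs = WriterT [ (x, "") | x <- xs ]

-- test2 rewritten with liftW in place of lift.
test2' :: WriterT String [] Integer
test2' = do
  a <- liftW [1,2]
  b <- liftW [2,4]
  tell ("trying " ++ show a ++ " " ++ show b ++ "\n")
  guard (a + b < 5)
  return (a + b)

-- runWriterT test2' == [(3,"trying 1 2\n"),(4,"trying 2 2\n")]
```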

There's an important point to note here: we have one log per sum, so the logs are 'local'. What if we want a 'global' side effect, such as a count of how many combinations were tried, regardless of whether they succeeded or failed? An obvious choice of monad for counting attempts is the State monad, but to make its effects 'global' we now need to make State the inner monad and have List provide the outer layer of skin. We're wrapping the opposite way from the previous example. And now there's a catch. We use a line like a <- [1,2] to exploit the List monad, but we no longer have a List monad; instead we have the ListT (State Integer) monad. This means that [1,2] is not a value in this monad. We can't use 'lift' either, because the inner monad isn't List. We need to translate our lists into the ListT (State Integer) monad.
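The difference between the two wrappings can be sketched without any transformers at all. Below, search2 and search3 are hypothetical names for hand-written versions of what runWriterT test2 and the ListT (State Integer) version that follows compute. The first keeps one log per result. The second threads a single counter through every branch, including the ones that fail.

```haskell
-- Local: every result carries its own log, as runWriterT test2 does.
search2 :: [(Integer, String)]
search2 =
  [ (a + b, "trying " ++ show a ++ " " ++ show b ++ "\n")
  | a <- [1,2], b <- [2,4], a + b < 5 ]

-- Global: one counter is threaded through all four combinations,
-- starting from s0; failed combinations still bump the count.
search3 :: Integer -> ([Integer], Integer)
search3 s0 = foldl step ([], s0) [ (a, b) | a <- [1,2], b <- [2,4] ]
  where
    step (results, count) (a, b) =
      ( if a + b < 5 then results ++ [a + b] else results
      , count + 1 )

-- search2   == [(3,"trying 1 2\n"),(4,"trying 2 2\n")]
-- search3 0 == ([3,4],4)
```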

We can do slightly better: we can translate a list from the List monad into any other instance of MonadPlus. Remember that, for lists, return x is the same as [x], and x `mplus` y is the same as x ++ y. For example, [1,2] == (return 1) `mplus` (return 2). The latter uses only functions from the MonadPlus interface to build the list, so it can be used to build the equivalent of a list in any MonadPlus. To mplus together a whole list we use msum, leading to the definition:


> mlist :: MonadPlus m => [a] -> m a
> mlist = msum . map return

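Because mlist uses only return and mplus, it can target any MonadPlus, not just ListT. Here is a quick sketch at Maybe. Maybe's mplus keeps the first success, so mlist picks out the head of the list if there is one and gives mzero (Nothing) otherwise. The names firstOf and noneOf are illustrative only, and mlist is repeated so the block stands alone.

```haskell
import Control.Monad (MonadPlus, msum)

mlist :: MonadPlus m => [a] -> m a
mlist = msum . map return

-- At Maybe, mplus keeps the first Just, so only the first element survives.
firstOf :: Maybe Integer
firstOf = mlist [1,2]

-- An empty list becomes mzero, which is Nothing at Maybe.
noneOf :: Maybe Integer
noneOf = mlist []

-- firstOf == Just 1, noneOf == Nothing
```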

As a function [a] -> [a], mlist is just the identity. Now we're ready to go:


> test3 :: ListT (State Integer) Integer

> test3 = do
>   a <- mlist [1,2]
>   b <- mlist [2,4]
>   lift $ modify (+1)    -- State is the inner monad, so lift once
>   guard $ a+b<5
>   return $ a+b

> go3 = runState (runListT test3) 0


Run go3 and you should get ([3,4],4): the two sums, plus a single count covering all four combinations tried. Note that we had to lift the modify line because the State monad is the inner one.

And now we have one more problem to solve: logging and counting simultaneously:


> test4 :: WriterT [Char] (ListT (State Integer)) Integer

> test4 = do
>   a <- lift $ mlist [1,2]           -- lift from ListT into WriterT
>   b <- lift $ mlist [2,4]
>   tell ("trying "++show a++" "++show b++"\n")
>   lift $ lift $ modify (+1)         -- State is two layers down
>   guard $ a+b<5
>   return $ a+b

> go4 = runState (runListT $ runWriterT test4) 0


That's it!

We can carry out a cute twist on this. By swapping the innermost and outermost monads we get:


> test5 :: StateT Integer (ListT (Writer [Char])) Integer

> test5 = do
>   a <- lift $ mlist [1,2]
>   b <- lift $ mlist [2,4]
>   lift $ lift $ tell ("trying "++show a++" "++show b++"\n")  -- Writer is now innermost
>   modify (+1)                       -- State is now the outer layer
>   guard $ a+b<5
>   return $ a+b

> go5 = runWriter $ runListT $ runStateT test5 0


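For each solution, go5 returns a local count of the combinations tried on the way to it. Here that count is always 1, since each branch tries a single combination. The Writer monad now records every 'try', including the failed ones, in one long log.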
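Here is a hand-written sketch, with no transformers, of the result go5 produces. search5 is a hypothetical name, and its parameter s0 plays the role of the initial state 0 passed to runStateT. Each branch gets its own copy of the counter, while all branches append to one shared log, so the failed tries "1 4" and "2 4" now appear in it.

```haskell
-- Hand-written equivalent of runWriter (runListT (runStateT test5 s0)):
-- each branch starts its own counter at s0, but all branches append to
-- a single shared log.
search5 :: Integer -> ([(Integer, Integer)], String)
search5 s0 = foldl step ([], "") [ (a, b) | a <- [1,2], b <- [2,4] ]
  where
    step (results, logSoFar) (a, b) =
      let logSoFar' = logSoFar ++ "trying " ++ show a ++ " " ++ show b ++ "\n"
          count     = s0 + 1  -- this branch's own counter after one modify (+1)
          results'  = if a + b < 5 then results ++ [(a + b, count)] else results
      in (results', logSoFar')

-- search5 0 == ([(3,1),(4,1)],"trying 1 2\ntrying 1 4\ntrying 2 2\ntrying 2 4\n")
```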

One last thing: you don't need to 'lift' things explicitly. The monad transformer library provides type classes, such as MonadState and MonadWriter, that lift some operations automatically. (You may need a recent Haskell distribution for this; it fails with older versions.)


> test6 :: WriterT [Char] (ListT (State Integer)) Integer

> test6 = do
>   a <- mlist [1,2]
>   b <- mlist [2,4]
>   tell ("trying "++show a++" "++show b++"\n")
>   modify (+1)
>   guard $ a+b<5
>   return $ a+b

> go6 = runState (runListT $ runWriterT test6) 0

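The automatic lifting in test6 comes from mtl's type classes. modify belongs to MonadState and tell to MonadWriter, and each transformer has instances that pass these operations through to the layer that implements them. Below is a sketch of the same search written once against those classes, then run in one concrete stack.

Some assumptions: the names search, searchStack and runSearch are made up. The stack WriterT String (StateT Integer []) was chosen only because it avoids ListT. Because State sits under the list here, the counter is local to each branch, so every result reports a count of 1.

```haskell
{-# LANGUAGE FlexibleContexts #-}
import Control.Monad (MonadPlus, guard, msum)
import Control.Monad.State (MonadState, StateT, modify, runStateT)
import Control.Monad.Writer (MonadWriter, WriterT, runWriterT, tell)

mlist :: MonadPlus m => [a] -> m a
mlist = msum . map return

-- The search written against the mtl classes, not a concrete stack:
-- no lifts are needed because tell and modify are class methods.
search :: (MonadPlus m, MonadWriter String m, MonadState Integer m) => m Integer
search = do
  a <- mlist [1,2]
  b <- mlist [2,4]
  tell ("trying " ++ show a ++ " " ++ show b ++ "\n")
  modify (+1)
  guard (a + b < 5)
  return (a + b)

-- One concrete stack satisfying all three constraints.
searchStack :: WriterT String (StateT Integer []) Integer
searchStack = search

runSearch :: [((Integer, String), Integer)]
runSearch = runStateT (runWriterT searchStack) 0

-- runSearch == [((3,"trying 1 2\n"),1),((4,"trying 2 2\n"),1)]
```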

It'd be cool to get rid of the mlist too. Perhaps the Haskell parser could be hacked so that [1,2] meant (return 1) `mplus` (return 2) rather than 1:2:[], rather like the way Gofer interprets list comprehensions in any monad. (For all I know, Gofer already does exactly what I'm suggesting.)
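For what it's worth, GHC has since gained a MonadComprehensions extension that interprets comprehension syntax in any monad, in the Gofer style mentioned above, with boolean conditions turned into guard. It doesn't remove mlist, because the generators must still live in the target monad. It does replace the do block and the explicit guard. Here is a sketch; sums is a hypothetical name.

```haskell
{-# LANGUAGE MonadComprehensions #-}
import Control.Monad (MonadPlus, msum)

mlist :: MonadPlus m => [a] -> m a
mlist = msum . map return

-- The condition a + b < 5 is desugared to a guard, so this one
-- definition works in any MonadPlus.
sums :: MonadPlus m => m Integer
sums = [ a + b | a <- mlist [1,2], b <- mlist [2,4], a + b < 5 ]

-- (sums :: [Integer])     == [3,4]
-- (sums :: Maybe Integer) == Just 3
```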

One thing I should add: these monad transformers really kill performance. My puzzle solver no longer finds any solutions within the few minutes it used to need...

PS I just made up the mlist thing. There may be a better way of doing this that I don't know about. I was surprised it wasn't already in the Control.Monad library somewhere. mlist is kind of a homomorphism between MonadPlusses and I think it might make the List MonadPlus an initial object in some category or other - but that's just speculation right now.
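The homomorphism idea can at least be spot-checked. Assume the usual MonadPlus monoid laws hold (mzero is a unit for mplus, and mplus is associative). Then mlist sends the list operations to the MonadPlus ones: mlist [] is mzero, mlist [x] is return x, and mlist (xs ++ ys) is mlist xs `mplus` mlist ys.

The sketch below tests the last two equations at Maybe on a few sample inputs. It also shows that, at Maybe, mlist does not commute with >>=, so the homomorphism here is of the mzero/mplus structure. The names appendLaw, singletonLaw and bindCounterexample are hypothetical.

```haskell
import Control.Monad (MonadPlus, mplus, msum)

mlist :: MonadPlus m => [a] -> m a
mlist = msum . map return

-- mlist (xs ++ ys) == mlist xs `mplus` mlist ys, checked at Maybe.
appendLaw :: [Int] -> [Int] -> Bool
appendLaw xs ys = (mlist (xs ++ ys) :: Maybe Int) == (mlist xs `mplus` mlist ys)

-- mlist [x] == return x, checked at Maybe.
singletonLaw :: Int -> Bool
singletonLaw x = (mlist [x] :: Maybe Int) == return x

-- >>= is not preserved at Maybe: with f 1 = [] and f 2 = [5],
-- mlist ([1,2] >>= f) is Just 5 but mlist [1,2] >>= (mlist . f) is Nothing.
bindCounterexample :: (Maybe Int, Maybe Int)
bindCounterexample = (mlist ([1,2] >>= f), mlist [1,2] >>= (mlist . f))
  where
    f :: Int -> [Int]
    f 1 = []
    f _ = [5]
```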

Update: I fixed the non-htmlised <'s. I think it takes more time to convert to blogger-compatible html than to write my posts! Also take a look at this. My mlist corresponds to their liftList - and now I know I wasn't completely off the rails writing mlist.

4 comments:

  1. One speed-killer is the mlist function, which is applied every single time a branch in the logic program is taken. Also, most lines of code written in monadic style contain implicit calls to >>=. As you layer up monad transformers, the implementation of >>= gets more and more complex.

  2. Hmmm...it's just become harder to proofread this. Not only do I have to read it here, but I also need to view it over at Planet Haskell. I see that the <- 'operators' are coming out differently over there...

  3. Great post.

    I copy/pasted your code, function by function, and played around a bit between each step to make sure I understood what was happening. By the time I got to test6 I ran into problems. I seem to have only a half-recent distribution of Haskell (ghc 6.6), because I could only get rid of half of the lifts :-) I still needed to lift the tell twice.

  4. magnus,

    I personally wouldn't worry too much about not being able to remove the lifts. If you use implicit lifts, then you can't use two state transformers of the same type simultaneously, which is pretty non-orthogonal. So I suggest always using the lifts.

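To illustrate the point in the last comment, here is a minimal sketch with two counters of the same type: one in an outer StateT and one in the inner State. An unlifted modify always resolves to the outer counter, so an explicit lift is the way to reach the inner one. The names bump and bumped are made up for the example.

```haskell
import Control.Monad.State (State, StateT, modify, runState, runStateT)
import Control.Monad.Trans.Class (lift)

-- Two Integer counters: the outer one in StateT, the inner one in State.
bump :: StateT Integer (State Integer) ()
bump = do
  modify (+1)           -- resolved by the MonadState instance of the outer StateT
  lift (modify (+10))   -- explicitly lifted, so it updates the inner State

-- Start the outer counter at 0 and the inner one at 100.
bumped :: (((), Integer), Integer)
bumped = runState (runStateT bump 0) 100

-- bumped == (((),1),110)
```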