Saturday, July 04, 2009

A Monad for Combinatorial Search with Heuristics

Haskell provides a great way to perform combinatorial searching with backtracking: the list monad. Do-notation provides a nice DSL that makes it easy to express trying out different possibilities. But the list monad only performs a simple-minded walk through all of the alternatives, giving little opportunity to direct that walk. In particular, it's not easy to provide heuristics that say things like "try this alternative first, but if it starts going badly, consider this alternative too". This post presents a monad that gives programmers a simple scheme for directing searches in this way.

First the Haskell administrivia...


> import Data.Char
> import Control.Applicative
> import Control.Monad
> import Data.Monoid
> import Data.List


When using the list monad, a list is interpreted as a list of candidates in a search. The join function for this monad takes a list of lists of candidates and flattens it into a list of candidates. This is all the list monad really does: you write code that generates new candidates from old, and the >>= function applies this code to all of the candidates it knows about and then flattens this back out to a single list of candidates. Importantly it does this in a lazy way so that you only need look at candidates as they are generated.
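A small standalone illustration of that description (not part of the literate program; `expand` and the other names are made up for the example): `join` flattens a list of candidate lists, and `>>=` is `map` followed by that flattening. Taking a prefix of the result also shows the laziness: only the candidates that are demanded get generated, even from an infinite starting list.

```haskell
import Control.Monad (join)

-- Each candidate x generates two new candidates.
expand :: Int -> [Int]
expand x = [x, 10 * x]

-- join flattens a list of candidate lists into one list of candidates.
flattened :: [Int]
flattened = join [[1, 2], [3]]

-- >>= applies the generator to every candidate, then flattens.
viaBind :: [Int]
viaBind = [1, 2] >>= expand

-- The same thing spelled out as map followed by concat.
viaMapConcat :: [Int]
viaMapConcat = concat (map expand [1, 2])

-- Laziness: only as many candidates as are demanded get generated,
-- even though the list of starting candidates is infinite.
firstFew :: [Int]
firstFew = take 3 ([1 ..] >>= expand)
```

Here `flattened` is `[1,2,3]`, `viaBind` and `viaMapConcat` are both `[1,10,2,20]`, and `firstFew` is `[1,10,2]`.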

This new monad will keep slightly more information: each candidate will have a 'penalty' value attached to it saying how unattractive a candidate it is. Candidates with penalty 0 will be tried first, and those with penalty n will be tried after those with lower penalties. We can represent a collection of candidates and their penalties simply as a list of lists: the first inner list holds the candidates with penalty 0, the second holds those with penalty 1, and so on. We'll call these lists penalty lists and the positions within them slots.
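For illustration, here is that representation as a standalone fragment, together with a hypothetical helper `withPenalties` (my name, not part of the post) that reads off each candidate alongside its penalty, which is simply its slot index. It uses the same `P` and `unO` names as the definition that follows.

```haskell
-- Penalty lists: slot n holds the candidates with penalty n.
data PList a = P { unO :: [[a]] } deriving (Show, Eq)

-- Hypothetical helper: tag each candidate with its penalty, reading
-- the slots in order of increasing penalty. Works lazily on infinite lists.
withPenalties :: PList a -> [(Int, a)]
withPenalties (P slots) = [ (n, x) | (n, slot) <- zip [0 ..] slots, x <- slot ]

-- 'a' and 'b' cost nothing, slot 1 is empty, and 'c' has penalty 2.
example :: PList Char
example = P ["ab", "", "c"]
```

`withPenalties example` gives `[(0,'a'),(0,'b'),(2,'c')]`.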

Here's the definition of the penalty list type:


> data PList a = P { unO :: [[a]] } deriving (Show,Eq)


It's a functor in a straightforward way:


> instance Functor PList where
>   fmap f (P xs) = P (fmap (fmap f) xs)


The rule we'll adopt is that if you're trying a combination of two candidates then the penalty associated with the combination is the sum of the penalties of the individual candidates. To implement this we need an alternative version of the join operation. If we have a penalty list of penalty lists, and an element sits in the mth slot of the penalty list found in the nth slot, then we want it to end up in the (m+n)th slot of the final penalty list. Within a slot we can order the elements just as in the original list monad.
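As a check on that rule, here is a hypothetical reference version of the join (`joinSpec` is my name, not the post's) that computes each result slot directly from the index arithmetic. It only works on finite inputs, because it measures lengths up front, so it is a specification to test against rather than a replacement for the lazy join given in the post.

```haskell
-- Reference join for finite penalty lists of penalty lists.
-- outer !! n is a list of penalty lists; an element in slot m of one of
-- them lands in slot m + n of the result. Within a slot, contributions
-- come in order of n, then in the order the inner lists appear.
joinSpec :: [[[[a]]]] -> [[a]]
joinSpec outer =
  [ concat [ slotAt (k - n) inner
           | (n, innerSlot) <- zip [0 ..] outer
           , n <= k
           , inner <- innerSlot ]
  | k <- [0 .. maxSlot] ]
  where
    -- Highest slot index that can be non-empty (-1 if there is none).
    maxSlot = maximum (-1 : [ n + length inner - 1
                            | (n, innerSlot) <- zip [0 ..] outer
                            , inner <- innerSlot ])
    -- Slot i of a penalty list, treating missing slots as empty.
    slotAt i xs = if i < length xs then xs !! i else []
```

For example, with `[[1],[2]]` in outer slot 0 and `[[3]]` in outer slot 1, the 3 moves from slot 0 to slot 1 and joins the 2: `joinSpec [[[[1],[2]]],[[[3]]]]` is `[[1],[2,3]]`.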


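> -- The first element of a list, or mempty if the list is empty.
> headm :: Monoid m => [m] -> m
> headm (a:_) = a
> headm []    = mempty

> -- All but the first element of a list; an empty list stays empty.
> tailm :: Monoid m => [m] -> [m]
> tailm (_:as) = as
> tailm []     = []

> -- Merge several lists slot by slot, combining each slot with mconcat.
> -- On penalty lists this gathers together all candidates of equal penalty.
> zipm :: Monoid m => [[m]] -> [m]
> zipm ms | all null ms = []
> zipm ms = let
>     heads = map headm ms
>     tails = map tailm ms
>     h = mconcat heads
>     t = zipm (filter (not . null) tails)
>   in h : t

> instance Applicative PList where
>   pure x = P [[x]]
>   (<*>) = ap

> instance Monad PList where
>   x >>= f = let P xs = fmap (unO . f) x in P (join xs) where
>     -- An element in slot m of the penalty list found in slot n
>     -- of the outer list ends up in slot m+n of the result.
>     join [] = []
>     join (m:ms) = let
>         part1 = zipm m
>         part2 = join ms
>       in headm part1 : zipm [tailm part1, part2]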


Explaining how join is implemented would take many words, so I hope this picture of an example computation will do instead. I used the Monoid class simply so that the helpers don't have to refer directly to one level of bracket nesting. It is intended to be a proper Monad instance satisfying the three monad laws, but I haven't proved this, and it may occasionally leave trailing empty lists around; these have no impact on search results.
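To complement the picture, here is a standalone restatement of the same instance in modern GHC style (the Monoid helpers are specialised to lists, and `normalize` and `asBags` are hypothetical helpers of mine), used to replay a small join example and spot-check the laws. On these particular examples the identity laws hold exactly. Associativity holds only if each slot is treated as an unordered collection: both bracketings put 100 and 200 in slot 2, but in different orders. Either way the same candidates come out at the same penalties, which is what matters for search.

```haskell
import Control.Monad (ap)
import Data.List (sort)

data PList a = P { unO :: [[a]] } deriving (Show, Eq)

instance Functor PList where
  fmap f (P xs) = P (map (map f) xs)

-- The post's zipm specialised to lists: merge slot by slot.
zipm :: [[[a]]] -> [[a]]
zipm ms
  | all null ms = []
  | otherwise   = concatMap headm ms : zipm (filter (not . null) (map tailm ms))
  where
    headm (a:_) = a
    headm []    = []
    tailm (_:as) = as
    tailm []     = []

instance Applicative PList where
  pure x = P [[x]]
  (<*>) = ap

-- Same join as in the post, written without the Monoid helpers.
instance Monad PList where
  x >>= f = P (joinP (unO (fmap (unO . f) x)))
    where
      joinP [] = []
      joinP (m:ms) =
        let part1 = zipm m
        in headOr part1 : zipm [drop 1 part1, joinP ms]
      headOr (a:_) = a
      headOr []    = []

-- Hypothetical helpers: drop trailing empty slots, and additionally
-- forget the order of candidates within each slot.
normalize :: PList a -> PList a
normalize (P xs) = P (reverse (dropWhile null (reverse xs)))

asBags :: Ord a => PList a -> PList a
asBags p = let P xs = normalize p in P (map sort xs)

-- Worked example: 3 sits in slot 0 of a penalty list that is itself
-- in slot 1, so it lands in slot 1 next to 2.
worked :: PList Int
worked = P [[P [[1], [2]]], [P [[3]]]] >>= id

m0 :: PList Int
m0 = P [[1], [2]]

f0, g0 :: Int -> PList Int
f0 x = if x == 1 then P [[], [], [10]] else P [[20]]
g0 y = if y == 10 then P [[100]] else P [[], [200]]

-- The two bracketings of m0 >>= f0 >>= g0.
assocLHS, assocRHS :: PList Int
assocLHS = (m0 >>= f0) >>= g0
assocRHS = m0 >>= (\x -> f0 x >>= g0)
```

Here `worked` is `P [[1],[2,3]]`, `assocLHS` is `P [[],[],[200,100]]` and `assocRHS` is `P [[],[],[100,200]]`.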



> instance Alternative PList where
>   empty = P []
>   (P xs) <|> (P ys) = P (zipm [xs,ys])

> instance MonadPlus PList where
>   mzero = empty
>   mplus = (<|>)


We can use this much like the list monad. First it will search for possibilities with zero penalty. When these are exhausted it'll backtrack to the last place where it can start finding possibilities with penalty 1. Then it'll try penalty 2 and so on. Importantly it manages to do this lazily so that we don't explore penalty n+1 until we've finished penalty n.
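A standalone sketch of that behaviour, repeating the instances from above and adding the `Alternative` instance that `guard` needs (`nats`, `halves` and `merged` are illustrative names). A search for pairs with x == 2*y gets stuck with the list monad, whose inner loop over y never finishes for x = 1; with penalty lists, (2,1) turns up in slot 1 and (4,2) in slot 4. `merged` shows that `mplus` simply merges slot by slot.

```haskell
import Control.Applicative (Alternative (..))
import Control.Monad (ap, guard)

data PList a = P { unO :: [[a]] } deriving (Show, Eq)

instance Functor PList where
  fmap f (P xs) = P (map (map f) xs)

-- The post's zipm specialised to lists: merge slot by slot.
zipm :: [[[a]]] -> [[a]]
zipm ms
  | all null ms = []
  | otherwise   = concatMap headm ms : zipm (filter (not . null) (map tailm ms))
  where
    headm (a:_) = a
    headm []    = []
    tailm (_:as) = as
    tailm []     = []

instance Applicative PList where
  pure x = P [[x]]
  (<*>) = ap

-- The post's join: slot m of the penalty list in outer slot n goes to slot m+n.
instance Monad PList where
  x >>= f = P (joinP (unO (fmap (unO . f) x)))
    where
      joinP [] = []
      joinP (m:ms) =
        let part1 = zipm m
        in headOr part1 : zipm [drop 1 part1, joinP ms]
      headOr (a:_) = a
      headOr []    = []

instance Alternative PList where
  empty = P []
  (P xs) <|> (P ys) = P (zipm [xs, ys])

-- Every positive integer n, with penalty n - 1.
nats :: PList Int
nats = P (map (: []) [1 ..])

-- Pairs with x == 2 * y, found in order of increasing x + y.
halves :: PList (Int, Int)
halves = do
  x <- nats
  y <- nats
  guard (x == 2 * y)
  return (x, y)

-- mplus merges two penalty lists slot by slot.
merged :: PList Int
merged = P [[1], [2]] <|> P [[3], [], [4]]
```

`take 2 (concat (unO halves))` gives `[(2,1),(4,2)]`, and `merged` is `P [[1,3],[2],[4]]`.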

So now we can start using it. We'll hunt for Pythagorean triples by simply searching through all triples of positive integers, preferring solutions where the sum of the integers is as small as possible. So as our list of candidate integers we use P [[1],[2],[3],...]. In other words, the integer n has penalty n-1, so a triple's penalty is its sum minus 3. Here's the code:


> ex1 = do
>   x <- P $ map (\x -> [x]) [1..]
>   y <- P $ map (\y -> [y]) [1..]
>   z <- P $ map (\z -> [z]) [1..]
>   guard $ x*x+y*y==z*z
>   return $ (x,y,z)


Of course we wouldn't really search for Pythagorean triples this way. This is just an illustration of how to use the code. But note, crucially, that the equivalent code using the regular list monad would give us back no solutions. It'd start with x=1 and y=1 and then go off to infinity finding candidates for z. So as a side effect the penalty list allows us to tame some infinite searches.

Anyway, that was a simple numerical example. But this monad can be used with much more complex kinds of search. In fact it almost serves as a drop-in replacement for the list monad. This is a really nice example of the way separation of concerns is easy in Haskell. The task of generating candidates for search can easily be separated from the task of selecting from those candidates, even though the operations are highly interleaved during execution.

So here's a more complex example: writing a parser that can tolerate errors without running into combinatorial explosion. The idea is that we associate a penalty with each error. The penalty will make the parser run on the assumption of no errors until it can no longer parse, and then it'll backtrack on the assumption of one error until that assumption is no longer tenable and so on. We can liberally sprinkle 'erroneous' parsings throughout our code confident that these branches will only be taken in the event that an error-free parsing can't be found.

Firstly, here's a penalty list that we can use to introduce a penalty of just 1.


> penalty :: PList ()
> penalty = P [[],[()]]


If we stick that in the code path then anything following acquires a penalty of 1.
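The effect is easy to check in isolation. This standalone sketch repeats the instances from above (with an `Alternative` instance whose `<|>` is the slot-wise merge) and shows that one `penalty` moves a result to slot 1, two move it to slot 2, and that in a choice the penalised branch lands one slot behind the free one. The names `once`, `twice` and `choice` are illustrative.

```haskell
import Control.Applicative (Alternative (..))
import Control.Monad (ap)

data PList a = P { unO :: [[a]] } deriving (Show, Eq)

instance Functor PList where
  fmap f (P xs) = P (map (map f) xs)

-- The post's zipm specialised to lists: merge slot by slot.
zipm :: [[[a]]] -> [[a]]
zipm ms
  | all null ms = []
  | otherwise   = concatMap headm ms : zipm (filter (not . null) (map tailm ms))
  where
    headm (a:_) = a
    headm []    = []
    tailm (_:as) = as
    tailm []     = []

instance Applicative PList where
  pure x = P [[x]]
  (<*>) = ap

instance Monad PList where
  x >>= f = P (joinP (unO (fmap (unO . f) x)))
    where
      joinP [] = []
      joinP (m:ms) =
        let part1 = zipm m
        in headOr part1 : zipm [drop 1 part1, joinP ms]
      headOr (a:_) = a
      headOr []    = []

instance Alternative PList where
  empty = P []
  (P xs) <|> (P ys) = P (zipm [xs, ys])

-- The post's penalty: a single () in slot 1.
penalty :: PList ()
penalty = P [[], [()]]

once, twice, choice :: PList Char
once   = penalty >> return 'x'             -- 'x' lands in slot 1
twice  = penalty >> penalty >> return 'x'  -- penalties add up: slot 2
choice = return 'a' <|> (penalty >> return 'b')
```

Here `once` is `P ["","x"]`, `twice` is `P ["","","x"]` and `choice` is `P ["a","b"]`.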

Now we can write a parser. We can implement the parser from Hutton's paper on monadic parsing with very little modification: we simply replace the usual list with the penalty list and do away with the +++ operator, so that the parser is a bit more liberal about backtracking. Here's the parser type:


> newtype Parser a = Parser (String -> PList (a,String))


We could have parameterised that with the underlying monad so that we could have parsers with a choice of search strategy.
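One way to do that, sketched with hypothetical names of my own (`ParserT`, `runParserT`, `itemT`), is to make the search monad a type parameter and delegate choice to its `mplus`. Any `MonadPlus` can then be plugged in; the example uses the ordinary list monad and `Maybe`, and `PList` should slot in the same way once it has the instances given earlier. The type has the same shape as `StateT String m` from the transformers library.

```haskell
import Control.Applicative (Alternative (..))
import Control.Monad (MonadPlus (..), ap)

-- Hypothetical: a parser whose search strategy is the monad m.
newtype ParserT m a = ParserT { runParserT :: String -> m (a, String) }

instance Monad m => Functor (ParserT m) where
  fmap f p = ParserT $ \cs -> do
    (a, cs') <- runParserT p cs
    return (f a, cs')

instance Monad m => Applicative (ParserT m) where
  pure a = ParserT $ \cs -> return (a, cs)
  (<*>) = ap

instance Monad m => Monad (ParserT m) where
  p >>= f = ParserT $ \cs -> do
    (a, cs') <- runParserT p cs
    runParserT (f a) cs'

-- Choice is delegated to the underlying monad's mplus.
instance MonadPlus m => Alternative (ParserT m) where
  empty = ParserT (const mzero)
  p <|> q = ParserT $ \cs -> runParserT p cs `mplus` runParserT q cs

instance MonadPlus m => MonadPlus (ParserT m)

-- Hypothetical counterpart of Hutton's item.
itemT :: MonadPlus m => ParserT m Char
itemT = ParserT $ \cs -> case cs of
  []       -> mzero
  (c:rest) -> return (c, rest)
```

With the list monad, `runParserT ((itemT >> itemT) <|> itemT) "ab"` keeps both readings, `[('b',""),('a',"b")]`; with `Maybe` the same parser keeps only the first one that succeeds.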

The rest is a lot like in Hutton's paper:


> parse :: Parser a -> String -> PList (a,String)
> parse (Parser f) x = f x

> instance Functor Parser where
>   fmap = liftM

> instance Applicative Parser where
>   pure a = Parser (\cs -> P [[(a,cs)]])
>   (<*>) = ap

> instance Monad Parser where
>   p >>= f = Parser (\cs -> do
>     (a,cs') <- parse p cs
>     parse (f a) cs')

> instance Alternative Parser where
>   empty = Parser (\_ -> mzero)
>   p <|> q = Parser (\cs -> parse p cs `mplus` parse q cs)

> instance MonadPlus Parser where
>   mzero = empty
>   mplus = (<|>)

> item :: Parser Char
> item = Parser (\cs -> case cs of
>   "" -> mzero
>   (c:cs') -> P [[(c,cs')]])

> sat :: (Char -> Bool) -> Parser Char
> sat p = do
>   c <- item
>   if p c then return c else mzero

> char :: Char -> Parser Char
> char c = sat (c ==)


Now for a simple parsing problem. We'll parse simple arithmetical expressions much as in Hutton's paper. But I'm going to tolerate two kinds of error:
1. The shift key doesn't always work, so a shifted or unshifted version of a character occasionally appears in place of the intended one; and
2. parentheses are occasionally left out by the clumsy user.
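To see how these two kinds of error are priced, here is a compressed standalone version of the `keyChar` and `shouldHave` combinators defined later, on top of a copy of the penalty-list instances. As assumptions of this sketch: `char` is written directly rather than through `item` and `sat`, `avoid` builds its one-slot penalty directly, and `readings` is a hypothetical helper that lists each parse with its penalty. A '9' typed where '(' was meant (or vice versa) costs one, and so does a parenthesis that is missing altogether.

```haskell
import Control.Applicative (Alternative (..))
import Control.Monad (ap, liftM)

data PList a = P { unO :: [[a]] } deriving (Show, Eq)

instance Functor PList where
  fmap f (P xs) = P (map (map f) xs)

zipm :: [[[a]]] -> [[a]]
zipm ms
  | all null ms = []
  | otherwise   = concatMap headm ms : zipm (filter (not . null) (map tailm ms))
  where
    headm (a:_) = a
    headm []    = []
    tailm (_:as) = as
    tailm []     = []

instance Applicative PList where
  pure x = P [[x]]
  (<*>) = ap

instance Monad PList where
  x >>= f = P (joinP (unO (fmap (unO . f) x)))
    where
      joinP [] = []
      joinP (m:ms) =
        let part1 = zipm m
        in headOr part1 : zipm [drop 1 part1, joinP ms]
      headOr (a:_) = a
      headOr []    = []

instance Alternative PList where
  empty = P []
  (P xs) <|> (P ys) = P (zipm [xs, ys])

newtype Parser a = Parser (String -> PList (a, String))

parse :: Parser a -> String -> PList (a, String)
parse (Parser f) = f

instance Functor Parser where
  fmap = liftM

instance Applicative Parser where
  pure a = Parser (\cs -> pure (a, cs))
  (<*>) = ap

instance Monad Parser where
  p >>= f = Parser (\cs -> parse p cs >>= \(a, cs') -> parse (f a) cs')

instance Alternative Parser where
  empty = Parser (const empty)
  p <|> q = Parser (\cs -> parse p cs <|> parse q cs)

-- Written directly instead of via item and sat.
char :: Char -> Parser Char
char c = Parser $ \cs -> case cs of
  (x:rest) | x == c -> pure (x, rest)
  _                 -> empty

-- A penalty of one, with no input consumed.
avoid :: Parser ()
avoid = Parser (\cs -> P [[], [((), cs)]])

lowers, uppers :: String
lowers = "1234567890-=/"
uppers = "!@#$%^&*()_+?"

-- Read x exactly, or read a shift-key slip of x at a penalty of one.
keyChar :: Char -> Parser Char
keyChar x =
  char x
    <|> (avoid >> slip (lookup x (zip lowers uppers)))  -- shift pressed by mistake
    <|> (avoid >> slip (lookup x (zip uppers lowers)))  -- shift failed
  where
    slip Nothing  = empty
    slip (Just y) = char y >> return x

-- A possibly missing character: inserting it costs a penalty of one.
shouldHave :: Char -> Parser Char
shouldHave c = keyChar c <|> (avoid >> return c)

-- Hypothetical helper: list every reading with its penalty.
readings :: PList a -> [(Int, a)]
readings (P slots) = [ (n, r) | (n, slot) <- zip [0 ..] slots, r <- slot ]
```

For instance, `readings (parse (keyChar '(') "9")` is `[(1,('(',""))]` and `readings (parse (shouldHave ')') "")` is `[(1,(')',""))]`.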

Now we can code up a simple grammar for this. First the mapping between shifted and unshifted characters (on a Mac US keyboard):


> lowers = "1234567890-=/"
> uppers = "!@#$%^&*()_+?"
> lower x = lookup x (zip uppers lowers)
> upper x = lookup x (zip lowers uppers)

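> -- Read the shifted version of x (if it has one) as x.
> upperChar :: Char -> Parser Char
> upperChar x = case upper x of
>   Nothing -> mzero
>   Just y -> char y >> return x

> -- Read the unshifted version of x (if it has one) as x.
> lowerChar :: Char -> Parser Char
> lowerChar x = case lower x of
>   Nothing -> mzero
>   Just y -> char y >> return x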


A version of penalty wrapped for the parser monad:


> avoid :: Parser ()
> avoid = Parser $ \cs -> do
>   penalty
>   return ((),cs)


Reading keys on the assumption that the shift key may have failed:


> keyChar :: Char -> Parser Char
> keyChar x = char x `mplus` (avoid >> upperChar x) `mplus` (avoid >> lowerChar x)

> digit :: Parser Integer
> digit = do
>   x <- foldl mplus mzero (map keyChar "0123456789")
>   return (fromIntegral (ord x-ord '0'))

> number1 :: Integer -> Parser Integer
> number1 m = return m `mplus` do
>   n <- digit
>   number1 (10*m+n)

> number :: Parser Integer
> number = do
>   n <- digit
>   number1 n

> chainl :: Parser a -> Parser (a -> a -> a) -> a -> Parser a
> chainl p op a = (p `chainl1` op) `mplus` return a

> chainl1 :: Parser a -> Parser (a -> a -> a) -> Parser a
> p `chainl1` op = do {a <- p; rest a}
>   where
>     rest a = (do
>         f <- op
>         b <- p
>         rest (f a b)) `mplus` return a


Optional parentheses:


> shouldHave c = keyChar c `mplus` (avoid >> return c)


And the main part of the expression grammar:


> expr = term `chainl1` addop
> term = monomial `chainl1` mulop
> monomial = factor `chainl1` powop
> factor = number `mplus` do {shouldHave '('; n <- expr; shouldHave ')'; return n}
> powop = keyChar '^' >> return (^)
> addop = do {keyChar '+'; return (+)} `mplus` do {keyChar '-'; return (-)}
> mulop = do {keyChar '*'; return (*)} `mplus` do {keyChar '/'; return (div)}


Match the end of a string:


> end :: Parser ()
> end = Parser $ \cs ->
>   if null cs then P [[((),"")]] else mzero


We can test it out with:


> completeExpr = do
>   n <- expr
>   end
>   return n

> ex2 = parse completeExpr "2^(1+3"


When we run this we get no error-free parsing but we do get 3 readings with one error. One comes from reading the '(' as 9, one comes from inserting the missing ')' at the end and one comes from inserting ')' after '1'. Note that even for complex expressions we'll quickly find a 1- or 2-error parsing. For the regular list monad we might never get a parsing because there are an infinite number of ways of inserting parentheses.
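When all we want is the cheapest explanation of the input, a small helper can report the lowest penalty that has any readings at all. `bestSlot` is a hypothetical name of mine, not from the post; it scans the slots in order, so it returns as soon as a non-empty slot appears, but it would run forever on an infinite list of empty slots. Given the results described above, applying it to `ex2` should report penalty 1 together with the three one-error readings.

```haskell
data PList a = P { unO :: [[a]] } deriving (Show, Eq)

-- Hypothetical helper: the lowest penalty that has any candidates,
-- together with those candidates, or Nothing if every slot is empty.
-- Slots are scanned lazily, in order of increasing penalty.
bestSlot :: PList a -> Maybe (Int, [a])
bestSlot (P slots) =
  case dropWhile (null . snd) (zip [0 ..] slots) of
    []         -> Nothing
    (best : _) -> Just best
```

For example, `bestSlot (P [[],[],[5,16]])` is `Just (2,[5,16])`.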

Anyway, that was just a toy parsing problem. But a more complex application comes to my mind. Some written languages are tricky to parse because their orthography doesn't fully capture the phonetics of the original language, because there are few or no indicators of sentence or even word breaks, and because they have numerous optional orthographic and grammatical rules and use a script whose individual characters are occasionally hard to reliably identify. In such a situation it's good to have a parser driven by heuristics about what is likely to be intended and the penalty list monad might serve well. Here's an example of such a language.

Update: I forgot to add some connections to previous monads I've talked about:

  1. PList is a variation of the convolution monad I described here. It deals with the "wrong category" aspect so it is a true Haskell monad. Penalty lists form some kind of dual to the convolution comonad.
  2. It has much in common with this monad. That monad doesn't do anything smart about ordering searches but it does have the neat ability to 'fuse' different branches of a search so that different ways to arrive at the same place don't add to the combinatorial explosion. It's good for searches where you want to know what the minimum penalty is to get somewhere, but don't care what the best path actually is.

Also, in response to a comment on #haskell I've made the join example more complex so it's easier to generalise from it.