Saturday, March 17, 2012

Overloading Python list comprehension

Introduction
Python is very flexible in the way it allows you to overload various features of its syntax. For example, most of the binary operators can be overloaded. But one part of the syntax that can't be overloaded is list comprehension, i.e., expressions like [f(x) for x in y].

What might it mean to overload this notation? Let's consider something simpler first, overloading the binary operator +. The expression a+b is interpreted as a.__add__(b) if a is of class type. So overloading + means nothing more than writing a function. So if we can rewrite list comprehensions in terms of a function (or functions) then we can overload the notation by providing alternative definitions for those functions. Python doesn't provide a facility for doing this directly, but we can at least think about what it might mean to do this. Later we'll see how to tweak the Python interpreter to make it possible.
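To make the point about + concrete, here is a minimal sketch in Python 3 syntax. The Money class is a made-up example, not something from the original post. Overloading + amounts to defining a single method:

```python
class Money:
    """A made-up example class whose + is overloaded via __add__."""

    def __init__(self, cents):
        self.cents = cents

    def __add__(self, other):
        # Python evaluates a + b as a.__add__(b)
        return Money(self.cents + other.cents)


a, b = Money(150), Money(275)
print((a + b).cents)       # 425
print(a.__add__(b).cents)  # 425, the same call spelled out explicitly
```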

map
Consider the expression
[a for x in y]
Here the single-letter variables are 'metavariables' representing fragments of Python code. To a good approximation this is equal to:
map(lambda x: a, y)
(BTW, everything I say here is "to a good approximation". Python is an incredibly complex language and I'm not good enough at it to make any categorical statements about when one fragment of code is the same as another.)
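A quick check of that approximation, taking a to be x * x. This sketch assumes Python 3, where map returns a lazy iterator and has to be wrapped in list() before comparing; under Python 2, which the post targets, map already returns a list:

```python
y = [1, 2, 3, 4]

# [a for x in y] with a = x * x
comprehension = [x * x for x in y]

# map(lambda x: a, y), materialised as a list for comparison
mapped = list(map(lambda x: x * x, y))

print(comprehension == mapped)  # True
```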

So it's tempting to see list comprehensions as syntactic sugar for map, in which case one approach to overloading comprehension is to consider interpreting it in terms of replacements for map. But this isn't a very powerful overloading. It just gives us a slightly different way to write something that's already straightforward.

concatMap
Another reason for not simply seeing list comprehension in terms of map is that nested list comprehensions need another operation. Consider
[(y, z) for y in [1, 2] for z in ['a', 'b']]
This isn't quite the same as
[[(y, z) for z in ['a', 'b']] for y in [1, 2]]
but it's close. The second produces nested lists, whereas the first gives one flat list. We can think of nested comprehensions as applying a flattening operation. Let's use list comprehension to implement flattening:
def concat(xs):
    return [y for x in xs for y in x]
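Here is a small sketch showing the difference between the two comprehensions and how concat bridges it:

```python
def concat(xs):
    # Flatten one level: [[1, 2], [3]] -> [1, 2, 3]
    return [y for x in xs for y in x]


flat = [(y, z) for y in [1, 2] for z in ['a', 'b']]
nested = [[(y, z) for z in ['a', 'b']] for y in [1, 2]]

print(flat)                    # [(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b')]
print(nested)                  # [[(1, 'a'), (1, 'b')], [(2, 'a'), (2, 'b')]]
print(concat(nested) == flat)  # True
```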
We now write our nested comprehension as:
concat([[(y, z) for z in ['a', 'b']] for y in [1, 2]])
We know how to write non-nested comprehensions using map so we get:
concat(map(lambda y: [(y, z) for z in ['a', 'b']], [1, 2]))
And rewriting the inner comprehension we get:
concat(map(lambda y: map(lambda z: (y, z), ['a', 'b']), [1, 2]))
Every time we add another level of nesting we're going to need another concat. But the innermost map doesn't have a concat. Purely for reasons of symmetry we can ensure every map has a concat by enclosing the innermost element as a singleton list:
concat(map(lambda y: concat(map(lambda z: [(y, z)], ['a', 'b'])), [1, 2]))
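Each step of the rewrite can be checked mechanically. The following sketch (Python 3; map iterators are fine here because concat consumes them with a comprehension) confirms that every intermediate form gives the same flat list as the original comprehension:

```python
def concat(xs):
    return [y for x in xs for y in x]


target = [(y, z) for y in [1, 2] for z in ['a', 'b']]

step1 = concat([[(y, z) for z in ['a', 'b']] for y in [1, 2]])
step2 = concat(map(lambda y: [(y, z) for z in ['a', 'b']], [1, 2]))
step3 = concat(map(lambda y: map(lambda z: (y, z), ['a', 'b']), [1, 2]))
# Innermost element wrapped as a singleton list, so every map gets a concat
step4 = concat(map(lambda y: concat(map(lambda z: [(y, z)], ['a', 'b'])), [1, 2]))

print(all(s == target for s in [step1, step2, step3, step4]))  # True
```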
Every map now has a concat, so we can simplify slightly. Let's define:
def concatMap(f, xs):
    return [y for x in xs for y in f(x)]

def singleton(x):
    return [x]
Our expression becomes:
concatMap(lambda y: concatMap(lambda z: singleton((y, z)), ['a', 'b']), [1, 2])
Importantly we've completely rewritten the comprehension in terms of concatMap and singleton. By changing the meaning of these functions we can change the meaning of comprehension notation, or at least we could if the Python interpreter defined comprehension this way. It doesn't, but we can still reason about it. Although any comprehension that doesn't use ifs can be rewritten to use these functions, I won't give a formal description of the procedure. Instead I'll provide code to perform the rewrite later. While I'm at it, I'll also handle the ifs.
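As a sanity check, here is a sketch that evaluates the concatMap/singleton form against the original comprehension. It also shows one plausible treatment of an if clause: a failed test contributes the empty list, which concatMap then drops. That if translation is my assumption for illustration, not necessarily the exact rewrite mpython.py performs.

```python
def concatMap(f, xs):
    # Map first, then concatenate: apply f to each x and flatten the lists it returns
    return [y for x in xs for y in f(x)]


def singleton(x):
    return [x]


target = [(y, z) for y in [1, 2] for z in ['a', 'b']]
rewritten = concatMap(lambda y: concatMap(lambda z: singleton((y, z)), ['a', 'b']), [1, 2])
print(rewritten == target)  # True

# Assumed translation of an if clause: failing elements map to [] and vanish
evens = [x for x in range(10) if x % 2 == 0]
evens_rewritten = concatMap(lambda x: singleton(x) if x % 2 == 0 else [], range(10))
print(evens_rewritten == evens)  # True
```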

Laws
Freely redefining singleton and concatMap to redefine comprehension could get weird. If we're going to redefine them we should at least try to define them so that list comprehension still has some familiar properties. For example, for a list y we usually expect:
y == [x for x in y]
In other words
y == concatMap(lambda x: singleton(x), y)
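The law can be spot-checked on a few lists with the list versions of the two functions. The second check, that concatMap applied to a singleton just applies the function, is another property one would naturally expect; the test function f is an arbitrary choice of mine:

```python
def concatMap(f, xs):
    return [y for x in xs for y in f(x)]


def singleton(x):
    return [x]


# The law from the text: comprehending a list with the identity gives it back
for y in [[], [1], [1, 2, 3], ['a', 'b']]:
    assert y == concatMap(lambda x: singleton(x), y)

# Another expected law: binding a singleton simply applies the function
f = lambda n: [n, n * 10]
assert concatMap(f, singleton(7)) == f(7)

print("laws hold on these examples")
```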
At this point I could give a whole bunch more laws but it's time to own up.

Monads
A pair of functions singleton and concatMap, along with a bunch of laws, is essentially the same thing as a monad. In Haskell, concatMap is usually called bind (written >>=, with the arguments the other way round: xs >>= f) and singleton is called return. What I've done here is show what Wadler's paper Comprehending Monads might look like in Python. Haskell has specialised monad notation built into its grammar. But what's less well known is that so does Python! The catch is that although the grammar is right, the semantics can't be generalised beyond lists.
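Plain Python won't let comprehension syntax pick up a different pair, but the desugared form works with any pair. Here is a sketch with a Maybe-style pair. The encoding is my own choice for illustration: None means failure and a 1-tuple (x,) holds a result, so a legitimate None result can't be confused with failure. The safe_div and pair_up helpers are hypothetical.

```python
def concatMap(f, m):
    # Bind for "maybe": propagate failure, otherwise feed the wrapped value to f
    return None if m is None else f(m[0])


def singleton(x):
    # Return for "maybe": wrap a successful result
    return (x,)


def safe_div(a, b):
    return None if b == 0 else singleton(a / b)


def pair_up(m1, m2):
    # Same shape as the desugared [(y, z) for y in m1 for z in m2]
    return concatMap(lambda y: concatMap(lambda z: singleton((y, z)), m2), m1)


print(pair_up(safe_div(1, 2), safe_div(9, 3)))  # ((0.5, 3.0),)
print(pair_up(safe_div(1, 0), safe_div(9, 3)))  # None
```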

Monad-Python
One great thing about Python is that there seem to be libraries for working with every aspect of Python internals. So it's fairly easy to write a simple Python interpreter that rewrites list comprehensions to use singleton and concatMap. I've placed the source on GitHub. Use mpython.py instead of python as your interpreter. I've tested it with Python 2.6 and 2.7.

When using mpython, list comprehension uses whatever definitions of __concatMap__ and __singleton__ are currently in scope. By default they are the definitions I gave above, so we get something close to the usual list comprehension.
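For a rough idea of how such a rewrite can be done, here is a minimal Python 3 sketch using the standard ast module. It is not mpython.py, whose source isn't reproduced here. It handles only `for NAME in EXPR` clauses (no ifs, no tuple targets), and the names ComprehensionRewriter, run_rewritten and list_concatMap are made up for illustration.

```python
import ast


class ComprehensionRewriter(ast.NodeTransformer):
    """Rewrite list comprehensions into nested __concatMap__/__singleton__ calls.

    Sketch only: supports 'for NAME in EXPR' clauses and rejects ifs and
    tuple targets, which a fuller tool would also need to translate.
    """

    def visit_ListComp(self, node):
        self.generic_visit(node)  # rewrite any comprehensions nested inside first
        # Innermost element: __singleton__(elt)
        body = ast.Call(func=ast.Name(id='__singleton__', ctx=ast.Load()),
                        args=[node.elt], keywords=[])
        # Wrap from the last for clause outwards: __concatMap__(lambda x: body, iter)
        for gen in reversed(node.generators):
            if gen.ifs or not isinstance(gen.target, ast.Name):
                raise NotImplementedError("only 'for NAME in EXPR' clauses are supported")
            lam = ast.parse('lambda %s: None' % gen.target.id, mode='eval').body
            lam.body = body
            body = ast.Call(func=ast.Name(id='__concatMap__', ctx=ast.Load()),
                            args=[lam, gen.iter], keywords=[])
        return ast.copy_location(body, node)


def run_rewritten(source, namespace):
    """Parse source, rewrite its comprehensions, and execute it in namespace."""
    tree = ComprehensionRewriter().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    exec(compile(tree, '<rewritten>', 'exec'), namespace)
    return namespace


# List-monad definitions reproduce ordinary comprehension behaviour
def list_concatMap(f, xs):
    return [y for x in xs for y in f(x)]


ns = run_rewritten("result = [(y, z) for y in [1, 2] for z in 'ab']",
                   {'__concatMap__': list_concatMap, '__singleton__': lambda x: [x]})
print(ns['result'])  # [(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b')]

# Swapping in set-based definitions changes what the same syntax means
ns2 = run_rewritten("result = [x % 3 for x in range(10)]",
                    {'__concatMap__': lambda f, xs: {y for x in xs for y in f(x)},
                     '__singleton__': lambda x: {x}})
print(sorted(ns2['result']))  # [0, 1, 2]
```

Whatever names are bound to `__concatMap__` and `__singleton__` in the execution namespace decide what the comprehension means, which mirrors the scoping behaviour described above.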

An example of the kind of code you can run with mpython.py is:
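import math

# Comprehension syntax now builds continuation-monad computations:
# a computation is a function that takes a continuation c.
def __concatMap__(k, m):
  # bind: run m, pass its result a to k, then continue with c
  return lambda c:m(lambda a:k(a)(c))

def __singleton__(x):
  # return: hand x straight to the continuation
  return lambda f:f(x)

def callCC(f):
  # f receives an escape function that jumps straight to c,
  # discarding whatever continuation it is later given
  return lambda c:f(lambda a:lambda _:c(a))(c)

def __fail__():
  # String exceptions are not allowed in Python 2.6+, so raise a real exception
  raise Exception("Failure is not an option for continuations")

def ret(x):
  return __singleton__(x)

def identity(x):
  return x

def solve(a, b, c):
  # Each 'for' clause either passes a value along or escapes via throw
  return callCC(lambda throw: [((-b-d)/(2*a), (-b+d)/(2*a))
                               for a0 in (throw("Not quadratic") if a==0 else ret(a))
                               for d2 in ret(b*b-4*a*c)
                               for d in (ret(math.sqrt(d2)) if d2>=0 else throw("No roots"))
                              ])

print(solve(1, 0, -9)(identity))
print(solve(1, 1, 9)(identity))
print(solve(0, 1, 9)(identity))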
I have defined these functions so that comprehension syntax gives us the continuation monad. This makes continuation passing style relatively painless in Python. (At least, it's easier than chaining many lambdas by hand.) I have then defined callCC to be similar to its definition in Haskell. There are many uses for callCC, including the implementation of goto. Above I use it in a trivial way to throw exceptions.
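For readers who find the chained lambdas hard to follow, here is a sketch of roughly what the comprehension in solve turns into, assuming one `__concatMap__` per for clause and a `__singleton__` around the element as described earlier. It is written with named inner functions in plain Python 3, so it runs without mpython.py; bind, unit, call_cc and identity are my own stand-in names for `__concatMap__`, `__singleton__`, callCC and id.

```python
import math


def bind(k, m):  # plays the role of __concatMap__(k, m)
    return lambda c: m(lambda a: k(a)(c))


def unit(x):  # plays the role of __singleton__(x)
    return lambda f: f(x)


def call_cc(f):
    # throw(a) builds a computation that ignores its continuation and jumps to c
    return lambda c: f(lambda a: lambda _: c(a))(c)


def identity(x):
    return x


def solve(a, b, c):
    def body(throw):
        # for a0 in (throw("Not quadratic") if a == 0 else ret(a))
        check_a = throw("Not quadratic") if a == 0 else unit(a)

        def after_a(a0):
            # for d2 in ret(b*b - 4*a*c)
            def after_d2(d2):
                # for d in (ret(math.sqrt(d2)) if d2 >= 0 else throw("No roots"))
                root = unit(math.sqrt(d2)) if d2 >= 0 else throw("No roots")

                def after_d(d):
                    # The comprehension's element, wrapped as by __singleton__
                    return unit(((-b - d) / (2 * a), (-b + d) / (2 * a)))
                return bind(after_d, root)
            return bind(after_d2, unit(b * b - 4 * a * c))
        return bind(after_a, check_a)
    return call_cc(body)


print(solve(1, 0, -9)(identity))  # (-3.0, 3.0)
print(solve(1, 1, 9)(identity))   # No roots
print(solve(0, 1, 9)(identity))   # Not quadratic
```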

Conclusion
My script mpython.py is a long way from an industrial strength interpreter and I'm not proposing the above as an extension to Python. My goal was simply to show how Haskell-style monads are not as alien to Python as you might think. In fact, it's reasonable to say that Python already supports one flavour of specialised monad syntax. Most users don't realise it as such because it has been hard-wired to work with just one monad, lists.

BTW if you attempt to implement all of the other Haskell monads you'll find that Haskell behaves a little differently because of its laziness. You can recover some of that laziness by careful use of continuations in Python. But I've no time to go into that now.

8 comments:

George Giorgidze said...

Nice to see that, with a little bit of effort, one can enable the use of Python's comprehension notation for writing monadic code.

Not so long ago, my colleagues and myself worked on a Haskell extension that allows the use of the comprehension notation for writing monadic code. In addition to the generator and filter clauses (covered in this post) the extension generalises parallel (zip) comprehensions and SQL-like comprehensions to monads.

I think parallel and SQL-like comprehensions would benefit Python as well.

The paper that describes the extension and gives several examples of its usage can be downloaded from the following link.

http://db.inf.uni-tuebingen.de/files/giorgidze/haskell2011.pdf

The extension has been implemented in GHC and has been available since version 7.2.

sigfpe said...

Hey @George!

Thanks for putting the monads back into Haskell comprehensions! I was so disappointed when I first read about them years ago, and then discovered they'd been removed.

Anonymous said...

There's a typo in your first use of "__concatMap__" in the main text.

sigfpe said...

@Anonymous,

As usual with me, see the actual tested code for the version with fewer typos :-)

Shin no Noir said...

I once toyed with the idea of introducing monadic syntax in Python by decompiling generators, but never managed to find the spare time to work on it. The syntax I had in mind was as follows:

from monad import do

do(x+y
   for x in foo
   for y in bar)

The do function would decompile the generator argument and rewrite it. The advantage is that you could still run your program using python instead of loading it differently.

(The disadvantage is that decompiling is more fragile and might break across Python versions and implementations)

Antti Rasinen said...

I would like to understand what is happening in the last example. However, the use of multiple chained anonymous functions makes it very difficult to understand the relationship between different operations. May I ask you to rewrite the example using the lameness of named variables and functions?

Thomas said...

There is a bug in the function concatMap in the text.

concatMap should map first, and concat second.

The test code does not have this bug.

Barnaby Robson said...

This is really cool. There is so much to discover in this post: 1. learning how to reprogram the Python interpreter at run time; 2. actually getting real monads in Python, which is probably the best implementation on the web so far! 3. learning about the continuation monad (in Python!).
