Saturday, March 17, 2012

Overloading Python list comprehension

Introduction
Python is very flexible in the way it allows you to overload various features of its syntax. For example, most of the binary operators can be overloaded. But one part of the syntax that can't be overloaded is list comprehension, i.e. expressions like [f(x) for x in y].

What might it mean to overload this notation? Let's consider something simpler first, overloading the binary operator +. The expression a+b is interpreted as a.__add__(b) if a is of class type. So overloading + means nothing more than writing a function. So if we can rewrite list comprehensions in terms of a function (or functions) then we can overload the notation by providing alternative definitions for those functions. Python doesn't provide a facility for doing this directly, but we can at least think about what it might mean to do this. Later we'll see how to tweak the Python interpreter to make it possible.
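As a concrete illustration of the point about +, here is a minimal sketch. The class Vec2 is a hypothetical example, not something used elsewhere in this post; it only shows that a + b ends up calling a function we wrote ourselves.

```python
class Vec2(object):
    """Hypothetical 2D vector, used only to illustrate overloading +."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __add__(self, other):
        # Called when evaluating `self + other`.
        return Vec2(self.x + other.x, self.y + other.y)

    def __eq__(self, other):
        return (self.x, self.y) == (other.x, other.y)


a = Vec2(1, 2)
b = Vec2(3, 4)
total = a + b  # dispatches to a.__add__(b)
```

Changing the body of __add__ changes what + means for Vec2 objects, which is the sense of "overloading" used in the rest of the post.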

map
Consider the expression
[a for x in y]
Here the single-letter variables are 'metavariables' representing fragments of Python code. To a good approximation (in Python 2, where map returns a list) this is equal to:
map(lambda x: a, y)
(BTW Everything I say here is "to a good approximation". Python is an incredibly complex language and I'm not good enough at it to make any categorical statements about when one fragment of code is the same as another.)
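A quick check of that claim, as a sketch. Here the metavariable a is taken to be the fragment x * x, which is just an arbitrary choice for illustration. Under Python 3, map returns a lazy iterator, so the result is wrapped in list(); under Python 2 the extra list() is harmless.

```python
y = [1, 2, 3]

# The metavariable `a` stands for the code fragment `x * x` here.
via_comprehension = [x * x for x in y]
via_map = list(map(lambda x: x * x, y))  # list() is needed on Python 3

print(via_comprehension == via_map)
```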

So it's tempting to see list comprehensions as syntactic sugar for map, in which case one approach to overloading comprehension is to consider interpreting it in terms of replacements for map. But this isn't a very powerful overloading. It just gives us a slightly different way to write something that's already straightforward.

concatMap
Another reason for not simply seeing list comprehension in terms of map is that nested list comprehensions need another operation. Consider
[(y, z) for y in [1, 2] for z in ['a', 'b']]
This isn't quite the same as
[[(y, z) for z in ['a', 'b']] for y in [1, 2]]
but it's close. The latter produces nested lists whereas the first gives one flat list. We can think of nested comprehensions as applying a flattening operation. Let's use list comprehension to implement flattening:
def concat(xs):
 return [y for x in xs for y in x]
We now write our nested comprehension as:
concat([[(y, z) for z in ['a', 'b']] for y in [1, 2]])
We know how to write non-nested comprehensions using map so we get:
concat(map(lambda y: [(y, z) for z in ['a', 'b']], [1, 2]))
And rewriting the inner comprehension we get:
concat(map(lambda y: map(lambda z: (y, z), ['a', 'b']), [1, 2]))
Every time we add another level of nesting we're going to need another concat. But the innermost map doesn't have a concat. Purely for reasons of symmetry we can ensure every map has a concat by enclosing the innermost element as a singleton list:
concat(map(lambda y: concat(map(lambda z: [(y, z)], ['a', 'b'])), [1, 2]))
Every map has a concat so we can simplify slightly. Let's define:
def concatMap(f, xs):
 return [y for x in xs for y in f(x)]

def singleton(x):
 return [x]
Our expression becomes:
concatMap(lambda y: concatMap(lambda z: singleton((y, z)), ['a', 'b']), [1, 2])
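The chain of rewrites can be checked directly. The sketch below takes concatMap to mean "apply f to each element, then flatten the resulting lists", and compares the original nested comprehension with the concat version and the final concatMap/singleton version.

```python
def concat(xs):
    return [y for x in xs for y in x]

def concatMap(f, xs):
    # Apply f to each element, then flatten the resulting lists.
    return [y for x in xs for y in f(x)]

def singleton(x):
    return [x]

direct = [(y, z) for y in [1, 2] for z in ['a', 'b']]

with_concat = concat([[(y, z) for z in ['a', 'b']] for y in [1, 2]])

with_concat_map = concatMap(
    lambda y: concatMap(lambda z: singleton((y, z)), ['a', 'b']),
    [1, 2])

print(direct)  # [(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b')]
print(direct == with_concat == with_concat_map)
```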
Importantly we've completely rewritten the comprehension in terms of concatMap and singleton. By changing the meaning of these functions we can change the meaning of comprehension notation, or at least we could if the Python interpreter defined comprehension this way. It doesn't, but we can still reason about it. Although any comprehension that doesn't use ifs can be rewritten to use these functions, I won't give a formal description of the procedure. Instead I'll provide code to perform the rewrite later. While I'm at it, I'll also handle the ifs.
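For the ifs, one plausible shape for the rewrite is to add a third function that produces an "empty" result when the condition fails. This is a sketch under that assumption rather than a description of the exact code the rewriter emits. The name fail is hypothetical, chosen to echo the __fail__ function that appears in the mpython example further down.

```python
def concatMap(f, xs):
    return [y for x in xs for y in f(x)]

def singleton(x):
    return [x]

def fail():
    # Hypothetical helper: the "no result" value for the list case.
    return []

ys = [1, 2, 3, 4, 5, 6]

filtered = [x * 10 for x in ys if x % 2 == 0]

# The condition chooses between a singleton result and no result at all.
rewritten = concatMap(
    lambda x: singleton(x * 10) if x % 2 == 0 else fail(),
    ys)

print(filtered == rewritten)
```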

Laws
Freely redefining singleton and concatMap to redefine comprehension could get weird. If we're going to redefine them we should at least try to define them so that list comprehension still has some familiar properties. For example, for y a list we usually expect:
y == [x for x in y]
In other words
y == concatMap(lambda x: singleton(x), y)
At this point I could give a whole bunch more laws but it's time to own up.

Monads
A pair of functions singleton and concatMap, along with a bunch of laws, is essentially the same thing as a monad. In Haskell, concatMap is usually called bind (with its arguments in the opposite order) and singleton is called return. What I've done here is show how Wadler's Comprehending Monads paper might look in Python. Haskell has specialised monad notation built into its grammar. But what's less well known is that so does Python! The catch is that although the grammar is right, the semantics can't be generalised beyond lists.
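The "bunch of laws" is conventionally taken to be the three monad laws: left identity, right identity and associativity. The sketch below checks them for the list versions of concatMap and singleton on a few sample values. The functions f and g and the list m are arbitrary choices for illustration, and a spot check like this is of course not a proof.

```python
def concatMap(f, xs):
    return [y for x in xs for y in f(x)]

def singleton(x):
    return [x]

# Two arbitrary list-returning functions and a sample list (hypothetical).
f = lambda x: [x, x + 1]
g = lambda x: [x * 2]
m = [1, 2, 3]

# Left identity: binding a singleton just applies the function.
left_identity = concatMap(f, singleton(5)) == f(5)

# Right identity: binding with singleton changes nothing.
# This is the law y == [x for x in y] from the previous section.
right_identity = concatMap(singleton, m) == m

# Associativity: the way binds are nested does not matter.
associativity = (concatMap(g, concatMap(f, m))
                 == concatMap(lambda x: concatMap(g, f(x)), m))

print(left_identity and right_identity and associativity)
```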

Monad-Python
One great thing about Python is that there seem to be libraries for working with every aspect of Python internals. So it's fairly easy to write a simple Python interpreter that rewrites list comprehensions to use singleton and concatMap. I've placed the source on GitHub. Use mpython.py instead of python as your interpreter. I've tested it with Python 2.6 and 2.7.

When using mpython, list comprehension uses whatever definitions of __concatMap__ and __singleton__ are currently in scope. By default they are the definitions I gave above so we get something close to the usual list comprehension.
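To give a feel for how such a rewrite can be done, here is a minimal sketch using the standard ast module. It is not the source of mpython.py: it targets Python 3, handles only simple loop variables (no tuple unpacking), and the names ComprehensionRewriter, run_rewritten and list_monad are hypothetical. It also assumes ifs are rewritten as a choice between the body and a call to __fail__().

```python
import ast

class ComprehensionRewriter(ast.NodeTransformer):
    """Rewrite list comprehensions into __concatMap__/__singleton__ calls."""

    def _call(self, name, args):
        return ast.Call(func=ast.Name(id=name, ctx=ast.Load()),
                        args=args, keywords=[])

    def visit_ListComp(self, node):
        self.generic_visit(node)  # rewrite nested comprehensions first
        result = self._call("__singleton__", [node.elt])
        # Work from the innermost `for` clause outwards.
        for gen in reversed(node.generators):
            if not isinstance(gen.target, ast.Name):
                raise NotImplementedError("only simple loop variables")
            body = result
            for cond in reversed(gen.ifs):
                body = ast.IfExp(test=cond, body=body,
                                 orelse=self._call("__fail__", []))
            # Borrow a parsed lambda so its argument list is well-formed.
            lam = ast.parse("lambda %s: None" % gen.target.id,
                            mode="eval").body
            lam.body = body
            result = self._call("__concatMap__", [lam, gen.iter])
        return ast.copy_location(result, node)

def run_rewritten(source, namespace):
    """Rewrite `source`, execute it in `namespace`, and return the namespace."""
    tree = ComprehensionRewriter().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    exec(compile(tree, "<rewritten>", "exec"), namespace)
    return namespace

# The ordinary list behaviour, expressed as the three hook functions.
list_monad = {
    "__concatMap__": lambda f, xs: [y for x in xs for y in f(x)],
    "__singleton__": lambda x: [x],
    "__fail__": lambda: [],
}

ns = run_rewritten(
    "pairs = [(y, z) for y in [1, 2] for z in ['a', 'b'] if y != 2]",
    dict(list_monad))
print(ns["pairs"])  # [(1, 'a'), (1, 'b')]
```

Passing a namespace with different definitions of the three hooks changes what the comprehension means, without touching the source being run.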

An example of the kind of code you can run with mpython.py is:
import math

def __concatMap__(k, m):
  return lambda c:m(lambda a:k(a)(c))

def __singleton__(x):
  return lambda f:f(x)

def callCC(f):
  return lambda c:f(lambda a:lambda _:c(a))(c)

def __fail__():
  raise Exception("Failure is not an option for continuations")

def ret(x):
  return __singleton__(x)

def id(x):
  return x

def solve(a, b, c):
  return callCC(lambda throw: [((-b-d)/(2*a), (-b+d)/(2*a))
                               for a0 in (throw("Not quadratic") if a==0 else ret(a))
                               for d2 in ret(b*b-4*a*c)
                               for d in (ret(math.sqrt(d2)) if d2>=0 else throw("No roots"))
                              ])

print solve(1, 0, -9)(id)
print solve(1, 1, 9)(id)
print solve(0, 1, 9)(id)
I have defined our functions so that comprehension syntax gives us the continuation monad. This makes continuation passing style relatively painless in Python. (At least easier than chaining many lambdas.) I have then defined callCC to be similar to its definition in Haskell. There are many uses for callCC including the implementation of goto. Above I use it in a trivial way to throw exceptions.
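To see what the comprehension in solve amounts to, here is a hand-written desugaring into explicit calls that runs on an unmodified Python. The nesting is an assumption about the general shape of the rewrite (one concatMap per for clause, a singleton around the result), not a dump of what mpython actually generates. The function identity stands in for id above so that the builtin is not shadowed.

```python
import math

def concatMap(k, m):
    # Bind for the continuation monad: run m, feed its result to k.
    return lambda c: m(lambda a: k(a)(c))

def singleton(x):
    return lambda f: f(x)

def callCC(f):
    # `throw` ignores the continuation it is handed and jumps to c instead.
    return lambda c: f(lambda a: lambda _: c(a))(c)

def identity(x):
    return x

def solve(a, b, c):
    return callCC(lambda throw:
        concatMap(lambda a0:
            concatMap(lambda d2:
                concatMap(lambda d:
                    singleton(((-b - d) / (2 * a), (-b + d) / (2 * a))),
                    singleton(math.sqrt(d2)) if d2 >= 0 else throw("No roots")),
                singleton(b * b - 4 * a * c)),
            throw("Not quadratic") if a == 0 else singleton(a)))

print(solve(1, 0, -9)(identity))  # (-3.0, 3.0)
print(solve(1, 1, 9)(identity))   # No roots
print(solve(0, 1, 9)(identity))   # Not quadratic
```

Comparing this with the comprehension version shows what the notation buys: the for clauses read top to bottom, while the explicit calls nest inside out.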

Conclusion
My script mpython.py is a long way from an industrial strength interpreter and I'm not proposing the above as an extension to Python. My goal was simply to show how Haskell-style monads are not as alien to Python as you might think. In fact, it's reasonable to say that Python already supports one flavour of specialised monad syntax. Most users don't recognise it as such because it has been hard-wired to work with just one monad, lists.

BTW if you attempt to implement all of the other Haskell monads you'll find that Haskell behaves a little differently because of its laziness. You can recover some of that laziness by careful use of continuations in Python. But I've no time to go into that now.