Sunday, April 12, 2026

A Simple Switch Makes Code Differentiable

Introduction

One game I like to play is decomposing code and algorithms into compositions of simpler pieces, even when they already seem as simple as they could be.

One example is the observation that adjoint mode automatic differentiation isn't a separate algorithm from forward mode automatic differentiation but a composition of forward mode and transposition. I talked about this in an old paper of mine, Two Tricks for the Price of One, and it resurfaced more recently in Jax: You Only Linearize Once.

Another example is the REDUCE algorithm, used to differentiate a class of stochastic processes despite their making hard decisions when sampling from a distribution. It turns out this algorithm is nothing but ordinary everyday importance sampling, generalized to probabilities lying in a non-standard algebraic structure. I describe it in a blog article here, and you can find out more about extending beyond the non-negative reals in a paper by Abramsky and Brandenberger. You sort of don't have to lift a finger to implement REDUCE - it just happens "for free" when you switch algebraic structure.
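For reference, classical importance sampling looks like this in Python (a toy sketch of my own, not code from any of the papers mentioned): to estimate an expectation under p you sample from a proposal q and reweight each draw by p/q.

```python
import random

def importance_sample(f, p, q, sample_q, n=100_000):
    """Estimate E_p[f(X)] using draws from a proposal q."""
    total = 0.0
    for _ in range(n):
        x = sample_q()
        total += f(x) * p(x) / q(x)   # reweight each draw by p/q
    return total / n

random.seed(0)
faces = [1, 2, 3, 4, 5, 6]
# Proposal biased toward high faces: q(x) = x/21.
est = importance_sample(
    f=lambda x: x * x,                 # estimate E[X^2] for a fair die
    p=lambda x: 1 / 6,
    q=lambda x: x / 21,
    sample_q=lambda: random.choices(faces, weights=faces)[0],
)
# True value is E[X^2] = 91/6 ≈ 15.17
```

The generalized version keeps this shape but lets the weights live in a different algebraic structure.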

Here's a teeny tiny example of another "for free" method that just appears when you switch types.

The Good Old Vector Space Monad

First let me write the same "vector space" monad that I've written many times before around here:
> module Main where

> import qualified Data.Map.Strict as Map
> import Control.Applicative
> import Control.Monad

> newtype V s a = V { unV :: [(a, s)] }
>   deriving (Show)

> instance Functor (V s) where
>   fmap f (V xs) = V [ (f a, s) | (a, s) <- xs ]

> instance Num s => Applicative (V s) where
>   pure x = V [(x, 1)]
>   mf <*> mx = do { f <- mf; x <- mx; return (f x) }

> instance Num s => Monad (V s) where
>   (V xs) >>= f = V [ (b, s * t) |
>                      (a, s) <- xs, (b, t) <- unV (f a) ]

> instance Num s => Alternative (V s) where
>   empty = V []
>   (<|>) = add

> instance Num s => MonadPlus (V s) where
>   mzero = empty
>   mplus = (<|>)

> normalize :: (Ord a, Num s, Eq s) => V s a -> V s a
> normalize (V xs) = V $ Map.toList $
>                 Map.filter (/= 0) $ Map.fromListWith (+) xs

> scale :: Num s => s -> V s a -> V s a
> scale c (V xs) = V [ (a, c * s) | (a, s) <- xs ]

> add :: Num s => V s a -> V s a -> V s a
> add (V xs) (V ys) = V (xs ++ ys)
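For readers who'd rather see this in Python, here's a rough analogue of V (my own sketch, not part of the literate Haskell program above): a vector is a list of (value, weight) pairs, bind multiplies weights along each path, and normalize merges duplicate values.

```python
from collections import defaultdict

def v_pure(x):
    return [(x, 1.0)]

def v_bind(xs, f):
    # bind: weights multiply along each path, as in the Haskell >>=
    return [(b, s * t) for (a, s) in xs for (b, t) in f(a)]

def v_add(xs, ys):
    return xs + ys

def scale(c, xs):
    return [(a, c * s) for (a, s) in xs]

def normalize(xs):
    # merge equal values, summing their weights
    acc = defaultdict(float)
    for a, s in xs:
        acc[a] += s
    return sorted((a, s) for a, s in acc.items() if s != 0)

die = [(i, 1 / 6) for i in range(1, 7)]
two = v_bind(die, lambda a: v_bind(die, lambda b: v_pure(a + b)))
# normalize(two) collapses the 36 paths down to the 11 possible sums
```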

Dictionaries

One of the challenges in machine learning is to replace standard algorithms that make hard decisions with methods that are soft and squishy so they're amenable to being differentiated. An example might be building dictionaries of key-value pairs. So let's write a line of code to insert into a dictionary:
> insert' :: [(k, v)] -> k -> v -> [(k, v)]
> insert' dict k v = (k, v) : dict
It's not very clever, but it does function as a dictionary builder. Let's write it in a more general way using the fact that the list functor is an instance of Alternative:
> insert :: Alternative m => m (k, v) -> k -> v -> m (k, v)
> insert dict k v = pure (k, v) <|> dict
We can go ahead and insert some entries:
> main :: IO ()
> main = do
>     let dict1 = empty :: [(Int, Int)]
>     let dict2 = insert dict1 0 1
>     let dict3 = insert dict2 1 0
>     print dict3

DeltaNet

So how can we rewrite that to make it more amenable to differentiation? We don't need to! It's already general enough. We just switch to using V instead of []:
>     let dict4 = empty :: V Double (Int, Int)
>     let dict5 = insert dict4 0 1
>     let dict6 = insert dict5 1 0
>     print $ normalize dict6
What has happened is that we've replaced the dictionary update with
D' = D + k⊗v
Addition and tensor products are more amenable to differentiation than appending to a list. When using the vector space monad (or applicative), the (,) operator plays a role more like the tensor product. Note also that we're no longer limited to inserting basis elements; we can use any suitably typed vectors:
>     let dict7 = dict6 <|>
>          (0.5 `scale` pure (0, 0) <|> 0.5 `scale` pure (1, 1))
>     print $ normalize dict7
This is the key ingredient in the update rule in DeltaNet, described in Parallelizing Linear Transformers with the Delta Rule over Sequence Length at the start of section 2.2.
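A minimal numpy sketch of that update (my own illustration; DeltaNet's actual rule also includes a removal term that I omit): the dictionary becomes a matrix D, inserting adds k⊗v, and lookup is a soft projection through D.

```python
import numpy as np

def insert(D, k, v):
    # D' = D + k ⊗ v : add a rank-1 association from key k to value v
    return D + np.outer(k, v)

def lookup(D, k):
    # a soft read: project the key through the accumulated associations
    return k @ D

dim = 4
D = np.zeros((dim, dim))
e = np.eye(dim)                     # basis vectors as one-hot keys/values
D = insert(D, e[0], e[1])           # key 0 -> value 1
D = insert(D, e[1], e[0])           # key 1 -> value 0
# as in dict7, we can insert non-basis vectors too:
D = insert(D, 0.5 * e[0] + 0.5 * e[1], 0.5 * e[0] + 0.5 * e[1])
# lookup(D, e[0]) now mixes the two inserts whose keys overlap e[0]
```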

Can we do this to other algorithms?

These dictionaries aren't very smart. Could we do a similar trick with a binary tree, or a hierarchical binary tree of lists? One way of approaching the binary comparisons turns this into a kind of mixture-of-experts model, but I'm not sure it's a good place to have MoE. Maybe there's another way?

Monday, March 30, 2026

"What does it take to be a hero?" revisited

Biased posteriors the hard way

I was previously interested to see how die rolls in an RPG appear when conditioned on your having survived an unlikely situation. As might have been predicted, if the die rolls contribute to that survival in a largely additive way, for example as damage scored against a large opponent, then the posterior distribution of the rolls looks exponentially tilted.

Biased posteriors an easier way

But the only virtue of the brute force Monte Carlo method I used was that it was easy to code; it's computationally wasteful. So I wrote a much more performant DSL in Python, which I have put on GitHub.

It uses numpy to achieve tolerable numerical performance, but in addition it uses two techniques beyond brute force to make it usable.

One challenge with a probabilistic language is to manage state.

First there's state "in the past": if you're computing probabilities that involve many large intermediate states, for example 100 die rolls, you run the risk of falling foul of combinatorial explosion.

If you write a loop like:

for i in range(100): t += d(6)

you want to be sure that the += operation erases history (i.e. previous values of t) so you aren't tracking all 6^100 individual histories. In this case it's easy, but in other cases it might not be so obvious that you have unneeded state lying around. So the code has a simple (and incomplete) backward liveness pass to insert deletions of data that won't be used again. Whenever state is deleted, you can merge histories that are now indistinguishable.
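The effect of that merging can be imitated directly in a few lines (my own illustration, not the DSL's implementation): keep a map from live total to probability and merge equal totals after every step.

```python
from collections import defaultdict
from fractions import Fraction

def step_die(states, sides=6):
    """One `t += d(sides)` step: convolve and merge equal totals."""
    out = defaultdict(Fraction)
    for t, p in states.items():
        for face in range(1, sides + 1):
            out[t + face] += p / sides   # merging keeps the state small
    return dict(out)

states = {0: Fraction(1)}
for _ in range(100):
    states = step_die(states)
# 100d6: only the 501 distinct totals 100..600 survive, not 6**100 histories
```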

And then there is state "in the future": sometimes you'd like to compute probabilities over data structures like lists, but materializing a list results in state that can cause combinatorial explosion. So I support Python-style generators, allowing you to generate data lazily - for example, permutations of cards. So we can bring state into existence just before we need it.

Auto batching

There's another important technique I used: this code uses brute force (though at this point maybe I should stop calling it brute force), so there are many states, each corresponding to a possible set of values for some numpy objects. And we often want to perform a numpy operation for each of these values. We don't need to loop: in many cases a parameterised family of numpy operations is in fact a single numpy operation. This kind of transformation is ubiquitous in GPU computing. So we can interpret the following D&D fight, summing over the combinatorially large number of ways it could happen, in a few seconds:

@d9.dist
def f():
    # Brachiosaurus (Monster Manual 1e p. 24)
    hp1 = lazy_sum(36 @ d(8))
    # Tyrannosaurus Rex (Monster Manual 1e p.28)
    hp2 = lazy_sum(18 @ d(8))

    for i in range(14):
        print("round", i)
        if hp1 > 0 and d(20) > 1:
            hp2 = max(0, hp2 - lazy_sum((x for x in 5 @ d(4))))

        if hp2 > 0:
            # Two claws...
            if d(20) > 1:
                hp1 -= d(6)
            if d(20) > 1:
                hp1 -= d(6)
            # ...and a bite
            if d(20) > 1:
                hp1 -= lazy_sum((x for x in 5 @ d(8)))
            hp1 = max(hp1, 0)

    win1 = hp2 == 0
    win2 = hp1 == 0

    return win1, win2
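The batching idea in isolation (a toy illustration of mine, not dice-nine's machinery): instead of looping over states and applying the same operation to each, stack the states into one array and apply a single numpy op.

```python
import numpy as np

rng = np.random.default_rng(0)

# 10,000 "states", each with its own hp value and its own die roll
hp = rng.integers(1, 50, size=10_000)
roll = rng.integers(1, 7, size=10_000)

# The slow way: a per-state Python loop.
slow = np.array([max(0, h - r) for h, r in zip(hp, roll)])

# One batched numpy op covering every state at once.
fast = np.maximum(0, hp - roll)
```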

Besides its own test suite, I also used a large number of questions from the RPG Stack Exchange to build a library of examples for testing.

Let's work in a general semiring

Of mathematical interest: most of the probability computations take place in a semiring, so most of the numerical computing is simply addition and multiplication. When that's all you're doing, there is a well known technique for working with large integers: you work modulo p[i] for some array of primes and only at the end reconstruct your final result using the Chinese remainder theorem. Less well known is that this works for rationals too. (I conjectured this was true, started deriving it myself, and then learnt there are published methods.) This means we can work with exact rational arithmetic using numpy without the need for a bignum library. This code turns out to be related to provenance semirings, as it effectively becomes a simple database tracking the provenance of each record.
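Here's the integer version of the trick with some small moduli of my own choosing (the rational case needs an extra rational-reconstruction step that I won't sketch here): do all the additions and multiplications independently mod each prime, then recombine with the Chinese remainder theorem.

```python
from math import prod

primes = [10007, 10009, 10037, 10039, 10061]

def crt(residues, mods):
    """Reconstruct x mod prod(mods) from x mod each modulus."""
    M = prod(mods)
    x = 0
    for r, m in zip(residues, mods):
        Mi = M // m
        x += r * Mi * pow(Mi, -1, m)   # pow(Mi, -1, m) is the modular inverse
    return x % M

a, b = 123456789, 987654321
# do the multiplication independently mod each prime...
residues = [(a % p) * (b % p) % p for p in primes]
# ...and reconstruct the exact product at the end
product = crt(residues, primes)
```

This recovers the exact answer as long as it stays below the product of the moduli.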

I originally wrote this code to target GPUs. On my Mac, numpy turned out to be comparable in speed to PyTorch and way faster than TensorFlow. I think this is because those libraries are optimised around data of fairly fixed shape passing through fixed pipelines whereas my code is very ad hoc. I've a feeling a few custom kernels would speed it up a lot. (I may be wrong about this but I do know I can write CUDA/Metal code directly that is many times faster than some of my dice-nine examples.)

Not just Python

And one final note. This is a deeply embedded DSL. I use Python as a host to give me an AST that I interpret. This isn't simply overloading of Python operators.
