Sunday, April 12, 2026

A Simple Switch Makes Code Differentiable

Introduction

One game I like to play is decomposing code and algorithms into compositions of simpler pieces, even when they already seem as simple as they could be.

One example is the observation that adjoint mode automatic differentiation isn't an algorithm separate from forward mode automatic differentiation, but a composition of forward mode and transposition. I talked about this in an old paper of mine, Two Tricks for the Price of One, and it resurfaced more recently in JAX, as described in You Only Linearize Once.

Another example is the REINFORCE algorithm, used to differentiate a class of stochastic processes even though they make hard decisions when sampling from a distribution. It turns out this algorithm is nothing but ordinary, everyday importance sampling, generalized to probabilities lying in a non-standard algebraic structure. I describe it in a blog article here, and you can find out more about extending beyond the non-negative reals in a paper by Abramsky and Brandenburger. You sort of don't have to lift a finger to implement REINFORCE: it just happens "for free" when you switch algebraic structure.

Here's a teeny tiny example of another "for free" method that just appears when you switch types.

The Good Old Vector Space Monad

First let me write the same "vector space" monad that I've written many times before around here:
> module Main where

> import qualified Data.Map.Strict as Map
> import Control.Applicative
> import Control.Monad

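> -- A vector: a formal linear combination of values of type a,
> -- each paired with a coefficient of type s.
> newtype V s a = V { unV :: [(a, s)] }
>   deriving (Show)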

> instance Functor (V s) where
>   fmap f (V xs) = V [ (f a, s) | (a, s) <- xs ]

> instance Num s => Applicative (V s) where
>   pure x = V [(x, 1)]
>   mf <*> mx = do { f <- mf; x <- mx; return (f x) }

> instance Num s => Monad (V s) where
>   (V xs) >>= f = V [ (b, s * t) |
>                      (a, s) <- xs, (b, t) <- unV (f a) ]

> instance Num s => Alternative (V s) where
>   empty = V []
>   (<|>) = add

> instance Num s => MonadPlus (V s) where
>   mzero = empty
>   mplus = (<|>)

> -- Collect like terms and drop any whose coefficient is zero.
> normalize :: (Ord a, Num s, Eq s) => V s a -> V s a
> normalize (V xs) = V $ Map.toList $
>                 Map.filter (/= 0) $ Map.fromListWith (+) xs

> scale :: Num s => s -> V s a -> V s a
> scale c (V xs) = V [ (a, c * s) | (a, s) <- xs ]

> add :: Num s => V s a -> V s a -> V s a
> add (V xs) (V ys) = V (xs ++ ys)
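As a quick sanity check of how these definitions behave, here is a small self-contained sketch. It repeats a compressed copy of the definitions above (with `<*>` written directly instead of via `do`), so it is not literally the same code. `pure` builds a basis vector, `<|>` adds vectors, `scale` multiplies by a scalar, `normalize` collects like terms, and `>>=` extends a function defined on basis elements linearly to whole vectors. The names `u` and `step` are purely illustrative.

```haskell
import qualified Data.Map.Strict as Map
import Control.Applicative

-- Compressed copy of the vector space monad from the text.
newtype V s a = V { unV :: [(a, s)] }
  deriving (Show)

instance Functor (V s) where
  fmap f (V xs) = V [ (f a, s) | (a, s) <- xs ]

instance Num s => Applicative (V s) where
  pure x = V [(x, 1)]
  V fs <*> V xs = V [ (f x, s * t) | (f, s) <- fs, (x, t) <- xs ]

instance Num s => Monad (V s) where
  V xs >>= f = V [ (b, s * t) | (a, s) <- xs, (b, t) <- unV (f a) ]

instance Num s => Alternative (V s) where
  empty = V []
  V xs <|> V ys = V (xs ++ ys)

-- Collect like terms and drop zero coefficients.
normalize :: (Ord a, Num s, Eq s) => V s a -> V s a
normalize (V xs) = V $ Map.toList $ Map.filter (/= 0) $ Map.fromListWith (+) xs

scale :: Num s => s -> V s a -> V s a
scale c (V xs) = V [ (a, c * s) | (a, s) <- xs ]

-- Illustrative vector: coefficient 2 on 'a' and 3 on 'b'.
u :: V Int Char
u = scale 2 (pure 'a') <|> scale 3 (pure 'b')

-- Illustrative map on basis elements; >>= extends it linearly.
step :: Char -> V Int Char
step 'a' = pure 'a' <|> pure 'b'
step c   = pure c

main :: IO ()
main = do
  print (normalize (u <|> pure 'a'))  -- like terms are combined
  print (normalize (u >>= step))      -- linear extension of step
```

This prints `V {unV = [('a',3),('b',3)]}` and then `V {unV = [('a',2),('b',5)]}`: the second line is 2·step('a') + 3·step('b').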

Dictionaries

One of the challenges in machine learning is to replace standard algorithms that make hard decisions with methods that are soft and squishy so they're amenable to being differentiated. An example might be building dictionaries of key-value pairs. So let's write a line of code to insert into a dictionary:
> insert' :: [(k, v)] -> k -> v -> [(k, v)]
> insert' dict k v = (k, v) : dict
It's not very clever, but it does function as a dictionary builder. Let's write it in a more general way, using the fact that lists are an instance of Alternative: an applicative functor that also has empty and <|>, which for lists are [] and ++:
> insert :: Alternative m => m (k, v) -> k -> v -> m (k, v)
> insert dict k v = pure (k, v) <|> dict
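A dictionary also needs a way to read from it. Here is a sketch of a lookup written, like `insert`, against an abstract interface: it only uses `>>=`, `guard` and `pure`, so it should work for any `MonadPlus`, not just lists. The name `lookupAll` is hypothetical and not part of the original code. For lists it returns every value stored under the key, most recent first.

```haskell
import Control.Applicative
import Control.Monad

-- Same generic insert as in the text.
insert :: Alternative m => m (k, v) -> k -> v -> m (k, v)
insert dict k v = pure (k, v) <|> dict

-- Hypothetical generic lookup: walk the entries, let guard discard
-- those whose key doesn't match, and return the remaining values.
lookupAll :: (Eq k, MonadPlus m) => m (k, v) -> k -> m v
lookupAll dict k = do
  (k', v) <- dict
  guard (k == k')
  pure v

main :: IO ()
main = do
  let dict = insert (insert (insert empty 0 1) 1 0) 0 5 :: [(Int, Int)]
  print (lookupAll dict 0)  -- both values stored under key 0
  print (lookupAll dict 2)  -- no entries for key 2
```

This prints `[5,1]` and then `[]`. Because nothing here mentions lists, the same `lookupAll` can later be run in other `MonadPlus` instances.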
We can go ahead and insert some entries:
> main :: IO ()
> main = do
>     let dict1 = empty :: [(Int, Int)]
>     let dict2 = insert dict1 0 1
>     let dict3 = insert dict2 1 0
>     print dict3

DeltaNet

So how can we rewrite that to make it more amenable to differentiation? We don't need to! It's already general enough. We just switch to using V instead of []:
>     let dict4 = empty :: V Double (Int, Int)
>     let dict5 = insert dict4 0 1
>     let dict6 = insert dict5 1 0
>     print $ normalize dict6
What has happened is that, viewing the dictionary D as a vector and the key k and value v as basis vectors, we've replaced the dictionary update with
D' = D + k⊗v
Addition and tensor product are more amenable to differentiation than appending to a list. When using the vector space monad (or applicative), the (,) operator plays a role much like the tensor product. Note that we're also no longer limited to inserting basis elements; we can use any suitably typed vectors:
>     let dict7 = dict6 <|>
>          (0.5 `scale` pure (0, 0) <|> 0.5 `scale` pure (1, 1))
>     print $ normalize dict7
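The claim that `(,)` acts like a tensor product can be made concrete. Given a key vector `k` and a value vector `v`, `(,) <$> k <*> v` pairs every key basis element with every value basis element and multiplies their coefficients, which gives exactly the coefficients of k⊗v. The sketch below uses a hypothetical helper `insertVec` (not from the original code) that adds that tensor to the dictionary. With basis vectors `pure k` and `pure v` it reduces to the original `insert`. It also shows that, unlike the list version, inserting the same pair twice simply doubles its coefficient.

```haskell
import qualified Data.Map.Strict as Map
import Control.Applicative

newtype V s a = V { unV :: [(a, s)] }
  deriving (Show)

instance Functor (V s) where
  fmap f (V xs) = V [ (f a, s) | (a, s) <- xs ]

instance Num s => Applicative (V s) where
  pure x = V [(x, 1)]
  V fs <*> V xs = V [ (f x, s * t) | (f, s) <- fs, (x, t) <- xs ]

instance Num s => Alternative (V s) where
  empty = V []
  V xs <|> V ys = V (xs ++ ys)

normalize :: (Ord a, Num s, Eq s) => V s a -> V s a
normalize (V xs) = V $ Map.toList $ Map.filter (/= 0) $ Map.fromListWith (+) xs

scale :: Num s => s -> V s a -> V s a
scale c (V xs) = V [ (a, c * s) | (a, s) <- xs ]

-- Hypothetical helper: D' = D + k ⊗ v for arbitrary vectors k and v.
insertVec :: Num s => V s (k, v) -> V s k -> V s v -> V s (k, v)
insertVec dict k v = ((,) <$> k <*> v) <|> dict

main :: IO ()
main = do
  -- With basis vectors this rebuilds dict6 from the text.
  let d1 = insertVec (insertVec empty (pure 0) (pure 1)) (pure 1) (pure 0)
             :: V Double (Int, Int)
  -- Inserting the pair (0, 1) again doubles its coefficient.
  print $ normalize (insertVec d1 (pure 0) (pure 1))
  -- A non-basis key, 0.5 of key 0 plus 0.5 of key 1, paired with value 1.
  let k = scale 0.5 (pure 0) <|> scale 0.5 (pure 1) :: V Double Int
  print $ normalize (insertVec d1 k (pure 1))
```

This prints `V {unV = [((0,1),2.0),((1,0),1.0)]}` and then `V {unV = [((0,1),1.5),((1,0),1.0),((1,1),0.5)]}`.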
This is the key ingredient in the update rule in DeltaNet, described in Parallelizing Linear Transformers with the Delta Rule over Sequence Length at the start of section 2.2.
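Reading from such a dictionary is where the linear algebra pays off. The following is a sketch under my own assumptions, not code from the paper. Running the `guard`-based lookup in V with a query vector q weights each stored value by its coefficient in D times q's coefficient on the matching key. That is D applied to q, the same readout linear attention uses. That readout then feeds the delta rule DeltaNet is built around, written here as D' = D + β k⊗(v − D k). Instead of just adding k⊗v, it first subtracts whatever is currently stored under k. The names `lookupVec` and `deltaInsert`, and the choice β = 1, are illustrative.

```haskell
import qualified Data.Map.Strict as Map
import Control.Applicative
import Control.Monad

newtype V s a = V { unV :: [(a, s)] }
  deriving (Show)

instance Functor (V s) where
  fmap f (V xs) = V [ (f a, s) | (a, s) <- xs ]

instance Num s => Applicative (V s) where
  pure x = V [(x, 1)]
  V fs <*> V xs = V [ (f x, s * t) | (f, s) <- fs, (x, t) <- xs ]

instance Num s => Monad (V s) where
  V xs >>= f = V [ (b, s * t) | (a, s) <- xs, (b, t) <- unV (f a) ]

instance Num s => Alternative (V s) where
  empty = V []
  V xs <|> V ys = V (xs ++ ys)

normalize :: (Ord a, Num s, Eq s) => V s a -> V s a
normalize (V xs) = V $ Map.toList $ Map.filter (/= 0) $ Map.fromListWith (+) xs

scale :: Num s => s -> V s a -> V s a
scale c (V xs) = V [ (a, c * s) | (a, s) <- xs ]

-- Hypothetical read: D q, i.e. the sum over entries k'⊗v' of <q, k'> v'.
-- In V, bind multiplies coefficients and guard drops mismatched keys.
lookupVec :: (Eq k, Num s) => V s (k, v) -> V s k -> V s v
lookupVec dict q = do
  (k', v) <- dict
  k <- q
  guard (k == k')
  pure v

-- Hypothetical delta-rule write: D' = D + beta * k ⊗ (v - D k).
deltaInsert :: (Eq k, Num s) => s -> V s (k, v) -> V s k -> V s v -> V s (k, v)
deltaInsert beta dict k v =
  scale beta ((,) <$> k <*> (v <|> scale (-1) (lookupVec dict k))) <|> dict

main :: IO ()
main = do
  let d = pure (0, 1) <|> pure (1, 0) :: V Double (Int, Int)
      q = scale 0.5 (pure 0) <|> scale 0.5 (pure 1) :: V Double Int
  -- Querying with a blend of keys 0 and 1 blends their values.
  print $ normalize (lookupVec d q)
  -- Plain addition would leave both 1 and 2 stored under key 0;
  -- the delta rule with beta = 1 replaces 1 by 2.
  let d' = deltaInsert 1 d (pure 0) (pure 2)
  print $ normalize (lookupVec d' (pure 0))
```

This prints `V {unV = [(0,0.5),(1,0.5)]}` and then `V {unV = [(2,1.0)]}`. The old value's coefficient cancels to zero and `normalize` drops it.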

Can we do this to other algorithms?

These dictionaries aren't very smart. Could we do a similar trick with a binary tree, or a hierarchical binary tree of lists? One way of approaching the binary comparisons turns this into a kind of mixture of experts (MoE) model, but I'm not sure this is a good place to use MoE. Maybe there's another way?