A Simple Switch Makes Code Differentiable
Introduction
One game I like to play is decomposing code and algorithms into compositions of simpler pieces, even when they already seem as simple as they could be. One example is the observation that adjoint mode automatic differentiation isn't a separate algorithm from forward mode automatic differentiation, but a composition of forward mode and transposition. I talked about this in an old paper of mine, Two Tricks for the Price of One, and it resurfaces more recently in JAX: You Only Linearize Once.
Another example is the REDUCE algorithm, used to differentiate a class of stochastic processes despite their making hard decisions when sampling from distributions. It turns out this algorithm is nothing but ordinary, everyday importance sampling, generalized to probabilities lying in a non-standard algebraic structure. I describe it in a blog article here, and you can find out more about extending beyond the non-negative reals in a paper by Abramsky and Brandenberger. You sort of don't have to lift a finger to implement REDUCE - it just happens "for free" when you switch algebraic structure.
Here's a teeny tiny example of another "for free" method that just appears when you switch types.
The Good Old Vector Space Monad
First let me write the same "vector space" monad that I've written many times before around here:
> module Main where
>
> import qualified Data.Map.Strict as Map
> import Control.Applicative
> import Control.Monad
>
> newtype V s a = V { unV :: [(a, s)] }
>   deriving (Show)
>
> instance Functor (V s) where
>   fmap f (V xs) = V [ (f a, s) | (a, s) <- xs ]
>
> instance Num s => Applicative (V s) where
>   pure x = V [(x, 1)]
>   mf <*> mx = do { f <- mf; x <- mx; return (f x) }
>
> instance Num s => Monad (V s) where
>   (V xs) >>= f = V [ (b, s * t) |
>                      (a, s) <- xs, (b, t) <- unV (f a) ]
>
> instance Num s => Alternative (V s) where
>   empty = V []
>   (<|>) = add
>
> instance Num s => MonadPlus (V s) where
>   mzero = empty
>   mplus = (<|>)
>
> normalize :: (Ord a, Num s, Eq s) => V s a -> V s a
> normalize (V xs) = V $ Map.toList $
>   Map.filter (/= 0) $ Map.fromListWith (+) xs
>
> scale :: Num s => s -> V s a -> V s a
> scale c (V xs) = V [ (a, c * s) | (a, s) <- xs ]
>
> add :: Num s => V s a -> V s a -> V s a
> add (V xs) (V ys) = V (xs ++ ys)
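As a quick sanity check of what this monad does (this example is my own, not part of the original post), bind acts as the linear extension of a function: it applies the function to each basis element and multiplies through by the coefficients. Here's a self-contained sketch using V Double as a probability monad, with a hypothetical fair coin distribution:

```haskell
import qualified Data.Map.Strict as Map

newtype V s a = V { unV :: [(a, s)] } deriving Show

instance Functor (V s) where
  fmap f (V xs) = V [ (f a, s) | (a, s) <- xs ]

instance Num s => Applicative (V s) where
  pure x = V [(x, 1)]
  V fs <*> V xs = V [ (f x, s * t) | (f, s) <- fs, (x, t) <- xs ]

instance Num s => Monad (V s) where
  V xs >>= f = V [ (b, s * t) | (a, s) <- xs, (b, t) <- unV (f a) ]

normalize :: (Ord a, Num s, Eq s) => V s a -> V s a
normalize (V xs) = V $ Map.toList $
  Map.filter (/= 0) $ Map.fromListWith (+) xs

-- A fair coin as the vector 0.5|True> + 0.5|False>.
coin :: V Double Bool
coin = V [(True, 0.5), (False, 0.5)]

-- XOR of two independent coins: bind multiplies coefficients,
-- normalize collects like terms.
twoFlips :: V Double Bool
twoFlips = do
  a <- coin
  b <- coin
  return (a /= b)

main :: IO ()
main = print (normalize twoFlips)
-- prints V {unV = [(False,0.5),(True,0.5)]}
```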
Dictionaries
One of the challenges in machine learning is to replace standard algorithms that make hard decisions with methods that are soft and squishy, so they're amenable to being differentiated. An example might be building dictionaries of key-value pairs. So let's write a line of code to insert into a dictionary:

> insert' :: [(k, v)] -> k -> v -> [(k, v)]
> insert' dict k v = (k, v) : dict

It's not very clever, but it does function as a dictionary builder. Let's write it in a more general way, using the fact that the list functor is an instance of Alternative (as well as Applicative):
> insert :: Alternative m => m (k, v) -> k -> v -> m (k, v)
> insert dict k v = pure (k, v) <|> dict

We can go ahead and insert some entries:
> main :: IO ()
> main = do
>   let dict1 = empty :: [(Int, Int)]
>   let dict2 = insert dict1 0 1
>   let dict3 = insert dict2 1 0
>   print dict3
DeltaNet
So how can we rewrite that to make it more amenable to differentiation? We don't need to! It's already general enough. We just switch to using V instead of []:
>   let dict4 = empty :: V Double (Int, Int)
>   let dict5 = insert dict4 0 1
>   let dict6 = insert dict5 1 0
>   print $ normalize dict6

What has happened is that we've replaced the dictionary update with

D' = D + k⊗v

Addition and tensor product are more amenable to differentiation than appending to a list. When using the vector space monad (or applicative), the (,) operator plays a role more like a tensor product.
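Reading from the soft dictionary has a nice interpretation too. Here's a sketch (the `query` function is my own illustration, not from the post): looking up a basis key contracts against the key slot of each k⊗v term, keeping the weights of the matching entries.

```haskell
import Control.Applicative

newtype V s a = V { unV :: [(a, s)] } deriving Show

instance Functor (V s) where
  fmap f (V xs) = V [ (f a, s) | (a, s) <- xs ]

instance Num s => Applicative (V s) where
  pure x = V [(x, 1)]
  V fs <*> V xs = V [ (f x, s * t) | (f, s) <- fs, (x, t) <- xs ]

instance Num s => Alternative (V s) where
  empty = V []
  V xs <|> V ys = V (xs ++ ys)

insert :: Alternative m => m (k, v) -> k -> v -> m (k, v)
insert dict k v = pure (k, v) <|> dict

-- Contract the key slot: keep the weighted values whose key matches k.
query :: (Eq k, Num s) => V s (k, v) -> k -> V s v
query (V xs) k = V [ (v, s) | ((k', v), s) <- xs, k' == k ]

d :: V Double (Int, Int)
d = insert (insert empty 0 1) 1 0

main :: IO ()
main = print (unV (query d 0))
-- prints [(1,1.0)]
```

With soft (non-basis) keys, this would generalize to an inner product against the key slot, but the exact-match version above is enough to show the shape of the contraction.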
Note that we're also no longer limited to inserting basis elements; we can use any suitably typed vectors:
>   let dict7 = dict6 <|>
>         (0.5 `scale` pure (0, 0) <|> 0.5 `scale` pure (1, 1))
>   print $ normalize dict7

This is the key ingredient in the update rule of DeltaNet, described at the start of section 2.2 of Parallelizing Linear Transformers with the Delta Rule over Sequence Length.
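The delta rule overwrites the old value stored at a key rather than only accumulating new terms. Here's a hypothetical sketch in the same spirit (the `update` function and the `beta` parameter are my own illustration, not code from the post or the paper): decay the weight of existing entries at a key by (1 - beta), then add the new entry with weight beta.

```haskell
import qualified Data.Map.Strict as Map

newtype V s a = V { unV :: [(a, s)] } deriving Show

normalize :: (Ord a, Num s, Eq s) => V s a -> V s a
normalize (V xs) = V $ Map.toList $
  Map.filter (/= 0) $ Map.fromListWith (+) xs

-- Decay the weight of entries at key k by (1 - beta), then add (k, v)
-- with weight beta, so the old value is overwritten rather than
-- just accumulated.
update :: (Eq k, Num s) => V s (k, v) -> s -> k -> v -> V s (k, v)
update (V xs) beta k v =
  V ([ ((k', v'), if k' == k then (1 - beta) * s else s)
     | ((k', v'), s) <- xs ] ++ [((k, v), beta)])

d0 :: V Double (Int, Int)
d0 = V [((0, 1), 1.0)]

main :: IO ()
main =
  -- Fully overwrite (beta = 1): the old entry's weight drops to zero
  -- and normalize filters it out.
  print (unV (normalize (update d0 1.0 0 2)))
-- prints [((0,2),1.0)]
```

All the operations involved are additions and scalings, so this stays as differentiable as the plain insert.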