A relaxation technique
Introduction
Sometimes you want to differentiate the expected value of something. I've written about some tools that can help with this. For example you can use Automatic Differentiation for the derivative part and probability monads for the expectation. But the probability monad I described in that article computes the complete probability distribution for your problem. Frequently this is intractably large. Instead people often use Monte Carlo methods. They'll compute the "something" many times, substituting pseudo-random numbers for the random variables, and then average the results. This provides an estimate of the expected value and is ubiquitous in many branches of computer science. For example it's the basis of ray-tracing and path-tracing algorithms in 3D rendering, and plays a major role in machine learning when used in the form of stochastic gradient descent.
But there's a catch.
Suppose we want to compute $\frac{d}{dp}E[f(X_1, \ldots, X_n)]$ where each of the $X_i$ belongs to the Bernoulli distribution $B(p)$.
I.e. each $X_i$ has a probability $p$ of being 1 and probability $1-p$ of being 0.
If we compute this using a Monte Carlo approach we'll repeatedly generate pseudo-random numbers for each of the $X_i$.
Each one will be 0 or 1.
This means that our estimate depends on $p$ via subexpressions that can't meaningfully be differentiated with respect to $p$.
So how can we use automatic differentiation with the Monte Carlo method?
I'm proposing an approach that may or may not already be in the literature.
Whether it is or not, I think it's fun to get there by combining many of the things I've previously talked about here, such as free monads, negative probabilities and automatic differentiation.
I'm going to assume you're familiar with using dual numbers to compute derivatives, as I've written about this before and Wikipedia has the basics.
A probability monad
I want to play with a number of different approaches to using monads with probability theory.
Rather than define lots of monads I think that the easiest thing is to simply work with one free monad and then provide different interpreters for it.
First some imports:
> import Control.Monad
> import qualified System.Random as R
> import qualified Data.Map.Strict as M

I'm going to use a minimal free monad that effectively gives us a DSL with a new function that allows us to talk about random Bernoulli variables:

> data Random p a = Pure a | Bernoulli p (Int -> Random p a)

The idea is that Pure a represents the value a and Bernoulli p f is used to say "if we had a random value x, f x is the value we're interested in". The Random type isn't going to do anything other than represent these kinds of expressions. There's no implication that we actually have a random value for x yet.

> instance Functor (Random p) where
>     fmap f (Pure a) = Pure (f a)
>     fmap f (Bernoulli p g) = Bernoulli p (fmap f . g)

> instance Applicative (Random p) where
>     pure = Pure
>     (<*>) = ap

> instance Monad (Random p) where
>     Pure a >>= f = f a
>     Bernoulli p g >>= f = Bernoulli p (\x -> g x >>= f)

We'll use bernoulli p to represent a random Bernoulli variable drawn from $B(p)$.

> bernoulli :: p -> Random p Int
> bernoulli p = Bernoulli p return

So let's write our first random expression:

> test1 :: Random Float Float
> test1 = do
>     xs <- replicateM 4 (bernoulli 0.75)
>     return $ fromIntegral $ sum xs

It sums 4 Bernoulli random variables from $B(0.75)$ and converts the result to a Float. The expected value is 3.
We don't yet have a way to do anything with this expression.
So let's write an interpreter that can substitute pseudo-random values for each occurrence of bernoulli p.
It essentially interprets our free monad as a state monad where the state is the random number seed:

> interpret1 :: (Ord p, R.Random p, R.RandomGen g) => Random p a -> g -> (a, g)
> interpret1 (Pure a) seed = (a, seed)
> interpret1 (Bernoulli prob f) seed =
>     let (r, seed') = R.random seed
>         b = if r <= prob then 1 else 0
>     in interpret1 (f b) seed'

You can use the expression R.getStdRandom (interpret1 test1) if you want to generate some random samples for yourself.
We're interested in the expected value, so here's a function to compute that:

> expect1 :: (Fractional p, Ord p, R.Random p, R.RandomGen g) => Random p p -> Int -> g -> (p, g)
> expect1 r n g =
>     let (x, g') = sum1 0 r n g
>     in (x/fromIntegral n, g')

> sum1 :: (Ord p, Num p, R.Random p, R.RandomGen g) => p -> Random p p -> Int -> g -> (p, g)
> sum1 t r 0 g = (t, g)
> sum1 t r n g =
>     let (a, g') = interpret1 r g
>     in sum1 (t+a) r (n-1) g'

You can test it out with R.getStdRandom (expect1 test1 1000). You should get values around 3.
We can try completely different semantics for Random.
This time we compute the entire probability distribution:

> interpret2 :: (Num p) => Random p a -> [(a, p)]
> interpret2 (Pure a) = [(a, 1)]
> interpret2 (Bernoulli p f) =
>     scale p (interpret2 (f 1)) ++ scale (1-p) (interpret2 (f 0))

> scale :: Num p => p -> [(a, p)] -> [(a, p)]
> scale s = map (\(a, p) -> (a, s*p))

You can try it with interpret2 test1.
Unfortunately, as it stands it doesn't collect together multiple occurrences of the same value.
We can do that with this function:

> collect :: (Ord a, Num b) => [(a, b)] -> [(a, b)]
> collect = M.toList . M.fromListWith (+)

And now you can use collect (interpret2 test1).
Let's compute some expected values:

> expect2 :: (Num p) => Random p p -> p
> expect2 r = sum $ map (uncurry (*)) (interpret2 r)

The value of expect2 test1 should be exactly 3. One nice thing about interpret2 is that it is differentiable with respect to the Bernoulli parameter when this is meaningful. Unfortunately it has one very big catch: the value of interpret2 can be a very long list. Every Bernoulli draw doubles its length, so $n$ draws produce $2^n$ entries. Even a small simulation can result in lists too big to store in the known universe. But interpret1 doesn't produce differentiable results. Is there something in-between these two interpreters?
Importance sampling
Frequently in Monte Carlo sampling it isn't convenient to sample from the distribution you want. For example it might be intractably hard to do so, or you might have proven that the resulting estimate has a high variance. So instead you can sample from a different, but possibly related, distribution. This is known as importance sampling. Whenever you do this you must keep track of how "wrong" your probability was and patch up your expectation estimate at the end. For example, suppose a coin comes up heads 3/4 of the time. Instead of simulating a coin toss that comes up heads 3/4 of the time you could simulate one that comes up heads half of the time. Suppose at one point in the simulation it does come up heads. Then you used a probability of 1/2 when you should have used 3/4. So when you compute the expectation you need to scale the contribution from this sample by (3/4)/(1/2) = 3/2. You need to scale appropriately for every random variable used. A straightforward way to see this for the case of a single Bernoulli variable is to note that

$$p f(1) + (1-p) f(0) = q \frac{p}{q} f(1) + (1-q) \frac{1-p}{1-q} f(0).$$

We've replaced probabilities $p$ and $1-p$ with $q$ and $1-q$ but we had to scale appropriately in each of the cases $X=1$ and $X=0$ to keep the final value the same. I'm going to call the scale value the importance. If we generate $n$ random numbers in a row we need to multiply all of the importance values that we generate. This is a perfect job for the Writer monad using the Product monoid. (See Eric Kidd's paper for some discussion about the connection between Writer and importance sampling.) However I'm just going to write an explicit interpreter for our free monad to make it clear what's going where.
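The reweighting identity for a single Bernoulli variable can be checked exactly with rational arithmetic. The following is a minimal sketch, separate from the literate program in this post; the names expectExact, expectReweighted and payoff are made up for illustration, with p the true probability and q the probability actually used for sampling.

```haskell
import Data.Ratio ((%))

-- Exact expectation of f under a Bernoulli variable with parameter p.
expectExact :: Rational -> (Int -> Rational) -> Rational
expectExact p f = p * f 1 + (1 - p) * f 0

-- The same expectation written as an average over a different
-- Bernoulli parameter q, with each case scaled by its importance.
expectReweighted :: Rational -> Rational -> (Int -> Rational) -> Rational
expectReweighted p q f =
  q * (p / q) * f 1 + (1 - q) * ((1 - p) / (1 - q)) * f 0

-- An arbitrary illustrative payoff: payoff 1 = 5, payoff 0 = 2.
payoff :: Int -> Rational
payoff x = fromIntegral (3 * x + 2)

main :: IO ()
main = do
  print (expectExact (3 % 4) payoff)              -- prints 17 % 4
  print (expectReweighted (3 % 4) (1 % 2) payoff) -- prints 17 % 4
```

Here the coin that should come up heads 3/4 of the time is replaced by a fair coin, and the importances 3/2 and 1/2 restore the original expectation.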
This interpreter is going to take an additional argument as input.
It'll be a rule saying what probability we should sample with when handling a variable drawn from $B(p)$.
The probability should be a real number in the interval $[0,1]$.

> interpret3 :: (Fractional p, R.RandomGen g) =>
>     (p -> Float) -> Random p a -> g -> ((a, p), g)
> interpret3 rule (Pure a) g = ((a, 1), g)
> interpret3 rule (Bernoulli p f) g =
>     let (r, g') = R.random g
>         prob = rule p
>         (b, i) = if (r :: Float) <= prob
>           then (1, p/realToFrac prob)
>           else (0, (1-p)/realToFrac (1-prob))
>         ((a, i'), g'') = interpret3 rule (f b) g'
>     in ((a, i*i'), g'')

Here's the accompanying code for the expectation:

> expect3 :: (Fractional p, R.RandomGen g) =>
>     (p -> Float) -> Random p p -> Int -> g -> (p, g)
> expect3 rule r n g =
>     let (x, g') = sum3 rule 0 r n g
>     in (x/fromIntegral n, g')

> sum3 :: (Fractional p, R.RandomGen g) =>
>     (p -> Float) -> p -> Random p p -> Int -> g -> (p, g)
> sum3 rule t r 0 g = (t, g)
> sum3 rule t r n g =
>     let ((a, imp), g') = interpret3 rule r g
>     in sum3 rule (t+a*imp) r (n-1) g'

For example, you can estimate the expectation of test1 using unbiased coin tosses by evaluating R.getStdRandom (expect3 (const 0.5) test1 1000).
Generalising probability
Did you notice I made my code slightly more general than seems to be needed? Although I use probabilities of type Float to generate my Bernoulli samples, the argument to the function bernoulli can be of a more general type. This means that we can use importance sampling to compute expected values for generalised measures that take values in a more general algebraic structure than the interval [0,1]. For example, we could use negative probabilities. An Operational Interpretation of Negative Probabilities and No-Signalling Models by Abramsky and Brandenburger gives a way to interpret expressions involving negative probabilities. We can implement it using interpret3 and the rule \p -> abs p/(abs p+abs (1-p)). Note that it is guaranteed to produce values in the range [0,1] (if you start with dual numbers with real parts that are ordinary probabilities) and reproduces the usual behaviour when given ordinary probabilities.
Here's a simple expression using a sample from "$B(2)$":

> test2 = do
>     a <- bernoulli 2
>     return $ if a==1 then 2.0 else 1.0

Its expected value is 3. We can get this exactly using expect2 test2. For a Monte Carlo estimate use

    R.getStdRandom (expect3 (\p -> abs p/(abs p+abs (1-p))) test2 1000)

Note that estimates involving negative probabilities can have quite high variances so try a few times until you get something close to 3 :-)
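It may help to see where that variance comes from. Working through interpret3 by hand for test2, with $p=2$ and the rule \p -> abs p/(abs p+abs (1-p)), gives the following (the symbols $q$, $w_1$ and $w_0$ are just labels introduced here for the sampling probability and the two importances):

```latex
\begin{align*}
q   &= \frac{|2|}{|2|+|1-2|} = \tfrac{2}{3} \\
w_1 &= \frac{p}{q} = \frac{2}{2/3} = 3      && \text{(importance when } a=1\text{)} \\
w_0 &= \frac{1-p}{1-q} = \frac{-1}{1/3} = -3 && \text{(importance when } a=0\text{)} \\
E   &= q\cdot 2\,w_1 + (1-q)\cdot 1\,w_0 = \tfrac{2}{3}\cdot 6 + \tfrac{1}{3}\cdot(-3) = 3
\end{align*}
```

So each sample contributes either 6 or -3. The average converges to 3, but individual samples swing widely around it.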
We don't have to stick with real numbers.
We can use this approach to estimate with complex probabilities (aka quantum mechanics) or other algebraic structures.
Discrete yet differentiable
And now comes the trick: automatic differentiation uses the algebra of dual numbers. It's not obvious at all what a probability like $p+\epsilon$ means when $\epsilon$ is infinitesimal. However, we can use interpret3 to give it meaningful semantics.
Let's define the duals in the usual way first:

> data Dual a = D { real :: a, infinitesimal :: a }

> instance (Ord a, Num a) => Num (Dual a) where
>     D a b + D a' b' = D (a+a') (b+b')
>     D a b * D a' b' = D (a*a') (a*b'+a'*b)
>     negate (D a b) = D (negate a) (negate b)
>     abs (D a b) = if a > 0 then D a b else D (-a) (-b)
>     signum (D a b) = D (signum a) 0
>     fromInteger a = D (fromInteger a) 0

> instance (Ord a, Fractional a) => Fractional (Dual a) where
>     fromRational a = D (fromRational a) 0
>     recip (D a b) = let ia = 1/a in D ia (-b*ia*ia)

> instance Show a => Show (Dual a) where
>     show (D a b) = show a ++ "[" ++ show b ++ "]"

Now we can use the rule real to give us a real-valued probability from a dual number. The function expect3 will push the infinitesimal part into the importance value so it doesn't get forgotten about. And now expect3 gives us an estimate that is differentiable despite the fact that our random variables are discrete.
Let's try an expression:

> test3 p = do
>     a <- bernoulli p
>     b <- bernoulli p
>     return $ if a == 1 && b == 1 then 1.0 else 0.0

The expected value is $p^2$ and the derivative is $2p$. We can evaluate at $p=0.5$ with expect2 (test3 (D 0.5 1)). And we can estimate it with

    R.getStdRandom (expect3 real (test3 (D 0.5 1)) 1000)

What's neat is that we can parameterise our distributions in a more complex way and we can freely mix with conventional expressions in our parameter. Here's an example:

> test4 p = do
>     a <- bernoulli p
>     b <- bernoulli (p*p)
>     return $ p*fromIntegral a*fromIntegral b

Try evaluating expect2 (test4 (D 0.5 1)) and

    R.getStdRandom (expect3 real (test4 (D 0.5 1)) 1000)

I've collected the above examples together here:

> main = do
>     print =<< R.getStdRandom (interpret1 test1)
>     print $ collect $ interpret2 test1
>     print =<< R.getStdRandom (expect1 test1 1000)
>     print (expect2 test1)
>     print =<< R.getStdRandom (expect3 id test1 1000)
>     print =<< R.getStdRandom (expect3 (const 0.5) test1 1000)
>     print "---"
>     print $ expect2 test2
>     print =<< R.getStdRandom (expect3 (\p -> abs p/(abs p+abs (1-p))) test2 1000)
>     print "---"
>     print $ expect2 (test3 (D 0.5 1))
>     print =<< R.getStdRandom (expect3 real (test3 (D 0.5 1)) 1000)
>     print "---"
>     print $ expect2 (test4 (D 0.5 1))
>     print =<< R.getStdRandom (expect3 real (test4 (D 0.5 1)) 1000)

What just happened?
You can think of a dual number as a real number that has been infinitesimally slightly deformed. To differentiate something we need to deform something. But we can't deform 0 or 1 and have them stay 0 or 1. So the trick is to embed probability sampling in something "bigger", namely importance sampling, where samples carry around an importance value. This bigger thing does allow infinitesimal deformations. And that allows differentiation. This process of turning something discrete into something continuously "deformable" is generally called relaxation.
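To make the deformation concrete, here is the arithmetic for test3 at the dual-number parameter $0.5+\epsilon$ (that is, D 0.5 1) with the rule real, worked by hand from the definition of interpret3 and using $\epsilon^2=0$:

```latex
\begin{align*}
\text{sampling probability} &= \mathrm{real}(0.5+\epsilon) = 0.5 \\
\text{importance when a draw is } 1 &= \frac{0.5+\epsilon}{0.5} = 1+2\epsilon \\
\text{importance when a draw is } 0 &= \frac{1-(0.5+\epsilon)}{1-0.5} = 1-2\epsilon \\
\text{importance when } a=b=1 &= (1+2\epsilon)^2 = 1+4\epsilon \\
E[\text{value}\times\text{importance}] &= \tfrac{1}{4}\cdot 1\cdot(1+4\epsilon) = 0.25+\epsilon
\end{align*}
```

The samples themselves are still plain 0s and 1s drawn with probability 0.5. Only the importances are deformed. The real part 0.25 is $p^2$ and the coefficient of $\epsilon$ is 1, which is $2p$ at $p=0.5$.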
Implementation details
I've made no attempt to make my code fast. However I don't think there's anything about this approach that's incompatible with performance. There's no need to use a monad. Instead you can track the importance value through your code by hand and implement everything in C. Additionally, I've previously written about the fact that for any trick involving forward mode AD there is another corresponding trick you can use with reverse mode AD. So this method is perfectly compatible with back-propagation. Note also that the dual number importances always have real part 1 which means you don't actually need to store them.
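As a rough sketch of what tracking the importance by hand might look like, here is test3 rewritten without the free monad or the Dual type. It stays in Haskell rather than C to match the rest of the post. Since the real part of the importance is always 1, only its infinitesimal coefficient is carried along, and multiplying importances becomes adding coefficients. Everything here is an assumption made for illustration: nextUnit is a small stand-in generator (so the sketch needs only base, not System.Random), and bernoulliScore, sampleTest3 and estimate are hypothetical names.

```haskell
import Data.Bits (shiftR)
import Data.Word (Word64)

-- Hypothetical stand-in for System.Random: a 64-bit linear congruential
-- step whose top 53 bits are turned into a Double in [0,1).
nextUnit :: Word64 -> (Double, Word64)
nextUnit s =
  let s' = s * 6364136223846793005 + 1442695040888963407
      u  = fromIntegral (s' `shiftR` 11) / 9007199254740992  -- divide by 2^53
  in (u, s')

-- Draw from B(p). Alongside the sample, return the coefficient of epsilon
-- in the importance: 1/p for a 1 and -1/(1-p) for a 0.
bernoulliScore :: Double -> Word64 -> ((Int, Double), Word64)
bernoulliScore p s =
  let (u, s') = nextUnit s
  in if u < p then ((1, 1 / p), s') else ((0, -1 / (1 - p)), s')

-- One sample of test3's value with its accumulated coefficient.
-- (1 + e*ca)(1 + e*cb) = 1 + e*(ca + cb), so the coefficients add.
sampleTest3 :: Double -> Word64 -> ((Double, Double), Word64)
sampleTest3 p g0 =
  let ((a, ca), g1) = bernoulliScore p g0
      ((b, cb), g2) = bernoulliScore p g1
      value = if a == 1 && b == 1 then 1 else 0
  in ((value, ca + cb), g2)

-- Average the value and value*coefficient over n samples, giving
-- estimates of the expectation and of its derivative with respect to p.
estimate :: Double -> Int -> Word64 -> (Double, Double)
estimate p n = go n 0 0
  where
    go 0 v d _ = (v / fromIntegral n, d / fromIntegral n)
    go k v d g =
      let ((x, c), g') = sampleTest3 p g
          v' = v + x
          d' = d + x * c
      in v' `seq` d' `seq` go (k - 1) v' d' g'

main :: IO ()
main = print (estimate 0.5 100000 42)
```

With p = 0.5 the two numbers should land near 0.25 and 1, the same quantities expect3 real (test3 (D 0.5 1)) estimates, though the exact digits depend on the generator and seed.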
The bad news is that the derivative estimate can sometimes have a high variance.
Nonetheless, I've used it successfully for some toy optimisation problems.
I don't know if this approach is effective for industrial strength problems.
Your mileage may vary :-)
Alternatives
Sometimes you may find that it is acceptable to deform the samples from your discrete distribution. In that case you can use the concrete relaxation.
Continuous variables
The above method can be adapted to work with continuous variables. There is a non-trivial step which I'll leave as an exercise but I've tested it in some Python code. I think it reproduces a standard technique and it gives an alternative way to think about that trick. That article is also useful for ways to deal with the variance issues. Note also that importance sampling is normally used itself as a variance reduction technique. So there are probably helpful ways to modify the rule argument to interpret3 to simultaneously estimate derivatives and keep the variance low.
Personal note
I've thought about this problem a couple of times over the years. Each time I've ended up thinking "there's no easy way to extend AD to work with random variables so don't waste any more time thinking about it". So don't listen to anything I say. Also, I like that this method sort of comes "for free" once you combine methods I've described previously.
Acknowledgements
I think it was Eric Kidd's paper on building probability monads that first brought to my attention that there are many kinds of semantics you can use with probability theory - i.e. there are many interpreters you can write for the Random monad. I think there is an interesting design space worth exploring here.
Answer to exercise
I set the continuous case as an exercise above. Here is a solution.
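Suppose you're sampling from a distribution parameterised by $\theta$ with pdf $f(x;\theta)$.
To compute the derivative with respect to $\theta$ you need to consider sampling from $f(x;\theta+\epsilon)$ where $\epsilon$ is an infinitesimal:

$$f(x;\theta+\epsilon) = f(x;\theta) + \epsilon\frac{\partial}{\partial\theta}f(x;\theta).$$

As we don't know how to sample from a pdf with infinitesimals in it, we instead sample using $f(x;\theta)$ as usual, but use an importance of

$$\frac{f(x;\theta)+\epsilon\frac{\partial}{\partial\theta}f(x;\theta)}{f(x;\theta)} = 1+\epsilon\frac{\partial}{\partial\theta}\log f(x;\theta).$$

The coefficient of the $\epsilon$ gives the derivative. So we need to compute the expectation, scaling each sample with this coefficient. In other words, to estimate $\frac{d}{d\theta}E[g(X)]$ we use

$$\frac{1}{N}\sum_{i=1}^{N} g(x_i)\frac{\partial}{\partial\theta}\log f(x_i;\theta)$$

where the $x_i$ are drawn from the original distribution. This is exactly what is described at Shakir Mohamed's blog.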
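As a sanity check, the expected value of each term $g(x)\frac{\partial}{\partial\theta}\log f(x;\theta)$ can be computed in closed form for a simple case. The exponential distribution with rate $\theta$, pdf $f(x;\theta)=\theta e^{-\theta x}$ on $x\ge 0$, and $g(x)=x$ are chosen here purely for illustration:

```latex
\begin{align*}
E[X] &= \frac{1}{\theta}, \qquad \frac{d}{d\theta}E[X] = -\frac{1}{\theta^2} \\
\frac{\partial}{\partial\theta}\log f(x;\theta)
  &= \frac{\partial}{\partial\theta}\left(\log\theta - \theta x\right) = \frac{1}{\theta} - x \\
E\left[X\left(\frac{1}{\theta} - X\right)\right]
  &= \frac{1}{\theta}E[X] - E[X^2] = \frac{1}{\theta^2} - \frac{2}{\theta^2} = -\frac{1}{\theta^2}
\end{align*}
```

The last line agrees with the derivative computed directly in the first, so averaging $x_i\left(\frac{1}{\theta}-x_i\right)$ over samples gives an unbiased estimate of it in this case.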
Final word
I managed to find the method in the literature. It's part of the REINFORCE method. For example, see equation (5) there.
7 Comments:
Nice and clear! Your method is new to me even though you found it in the literature in the end.
Regarding how to make a sampler for a continuous distribution (automatically) differentiable with respect to its parameters: your "answer to exercise" is the most direct generalization of your method, but there's also the "reparameterisation trick" described by an adjacent blog post by Shakir Mohamed.
Thanks!
Wonderful code!
Sorry to go off-topic for a moment here, Dan. I'm wanting to write a small interpreter and in looking around the web, saw your C code for SASL on Github and was instantly smitten - it's by far the nicest C code I've seen!
Would you mind if I used the lexing and parsing code from that?
I'm a bit of a "public domain" weenie, so I'm keen to release the interpreter as P.D. - would that be ok with you? I would certainly give acknowledgement to you as the source of the code - no problem there!
Thanks for your time - looking forward to hearing from you! Bye for now.
- Andy
jd43 you can do what you like with my code. I'm happy for any part you use to become PD.
Hi again Dan -
Great! Thanks very much for that!
I really like the use of lists in your C code for the SASL compiler. Very clean and elegant (as is SASL itself). It's a pity that more use hasn't been made of SASL around the world. That's also the case for Hope, another Haskell precursor language.
Both of them pack a lot of Haskell's elegance into a small package.
I'm keen to try the list approach with lexing and parsing of other languages. I might even see if it's possible to do a "mini-Parsec" using SASL as the base language. We'll see.... :)
Thanks again - keep up the good work with your blog here!
Bye for now -
- Andy
Hi Dan,
Is there an easy way to get a literate haskell version of posts like this? Or just copy and paste?
Thanks,
Andrew
Andrew, you should be able to literally copy and paste the text above into a text editor. Tell me what makes it fail if it does.
Hi Dan, Sorry for the delay. A few things give a problem, but they're not because of anything to do with literate Haskell, where the only problem is losing links and inline equations...
First, I've noticed that here and elsewhere, you must be using a fairly old version of GHC without the Functor-Applicative-Monad proposal so most of your monad instances have to be rejiggered, here and elsewhere (but that was actually a useful exercise to help me understand the hierarchy there!).
In various other posts errors do crop up, but usually the error message gives a guide to correcting it.
In this particular post, I do get an error that I can't seem to fix, in the definition of interpret1 (alas, blogger won't let me use "pre", so the formatting is terrible):
<interactive>:4:9: error:
• Could not deduce (R.Random t0)
from the context: (Ord p, R.Random p, R.RandomGen g)
bound by the type signature for:
interpret1 :: (Ord p, R.Random p, R.RandomGen g) => Random p a -> g -> (a, g)
at :1:1-77
The type variable ‘t0’ is ambiguous
• When checking that the inferred type
seed' :: forall t. R.Random t => g
is as general as its inferred signature
seed' :: g
In the expression:
let
(r, seed') = R.random seed
b = if r <= prob then 1 else 0
in interpret1 (f b) seed'
In an equation for ‘interpret1’:
interpret1 (Bernoulli prob f) seed
= let
(r, seed') = R.random seed
b = ...
in interpret1 (f b) seed'