Monday, April 27, 2026

Some type constructors are tensor products


Introduction

I want to return to something I've mentioned a couple of times in the past - the fact that applying certain type constructors performs a tensor product.

First some admin stuff:

> {-# LANGUAGE DeriveFunctor #-}
> {-# LANGUAGE FlexibleInstances #-}
> {-# LANGUAGE MultiParamTypeClasses #-}
> {-# LANGUAGE UndecidableInstances #-}
> {-# LANGUAGE TypeApplications #-}
> {-# LANGUAGE KindSignatures #-}
> {-# LANGUAGE ScopedTypeVariables #-}
> {-# LANGUAGE AllowAmbiguousTypes #-}

> import Data.Proxy
> import Data.Kind (Type)

> infixr 7 ⊗

Suppose you define a type like so:

> data Complex a = C a a
>     deriving (Eq, Show, Functor)

> instance Num a => Num (Complex a) where
>     fromInteger n = C (fromInteger n) 0

>     C a b + C c d = C (a + c) (b + d)
>     C a b - C c d = C (a - c) (b - d)

>     C a b * C c d = C (a * c - b * d) (a * d + b * c)

>     negate (C a b) = C (negate a) (negate b)

>     abs    = error "abs doesn't make sense here"
>     signum = error "signum makes no sense here"

It seems straightforward. You've defined complex numbers in a way that allows a choice of base type to represent the real numbers. For example you could use Complex Float or Complex Double as representations of \(\mathbb{C}\).

In actual fact you've done quite a bit more! That code has another reading - it implements a tensor product both in the category of vector spaces and, less trivially, in the category of algebras. So if A is a suitable algebraic structure then, mixing code and mathematical notation,

\[ \mathtt{Complex\ A} = \mathbb{C}\otimes\mathtt{A} \]

I took this for granted when I mentioned it previously but I thought I'd look into it in a little bit more detail.


Tensor Products

I want to start from the definition of the tensor product given by its universal property, but to make that slightly less fearsome I'll use an English sketch of it.

Suppose you have a pair of vector spaces \(X\) and \(Y\) over some base field \(k\). A bilinear function \(X\times Y\rightarrow Z\) is a function that is linear in \(X\) and linear in \(Y\). Now suppose we know that at some point in the future we are going to need some bilinear function on \(X\times Y\) but don't yet know what it is. Can we make a structure, \(T\), that contains precisely the information we need so that we can compute any bilinear function we want - with the proviso that we compute these bilinear functions by applying a linear function to \(T\)? We don't want \(T\) to be lacking anything we might need to compute a future bilinear product, but we also don't want it to contain any extraneous data.

For example, imagine working with \(V\), the vector space of 3D vectors. Some examples of bilinear functions we might want are the dot product \(V\times V\rightarrow\mathbb{R}\) and the cross product \(V\times V\rightarrow V\). What should \(T\) look like?

We can write the dot product as \((x, y, z)\cdot(x', y', z') = xx'+yy'+zz'\). Note how it's made of products of coordinates from \((x, y, z)\) and coordinates from \((x', y', z')\). Similarly \((x, y, z)\times(x', y', z')=(yz'-zy',\ldots)\). Again, it's a linear combination of products of coordinates, one from each vector. You can prove that any bilinear product will be some linear combination of such products.

By thinking about all possible bilinear products, I hope you can see that \(T\) should be a 9-dimensional vector space, and that a suitable way to represent a pair of vectors \((x, y, z), (x', y', z')\) for future application of a bilinear function is as \((xx', xy', xz', yx', yy', yz', zx', zy', zz')\). Any bilinear product is a linear combination of these 9 quantities and so is given by some linear operation on \(T\). It is commonplace to arrange the 9 numbers as a \(3\times 3\) matrix, in which case the map from the pair of vectors is called the outer product. But the arrangement doesn't really matter as all 9-dimensional vector spaces over a given field are isomorphic.

In this case I chose to consider bilinear functions on \(V\times V\), but you can reason similarly for any pair of vector spaces \(X\) and \(Y\). When working with finite-dimensional vector spaces, the structure we need will be \(mn\)-dimensional, where \(m\) is the dimension of \(X\) and \(n\) is the dimension of \(Y\). The structure is called the tensor product and is written \(X\otimes Y\). The bilinear map from the original vectors into the tensor product is also called the tensor product and is written as a binary operator \(x\otimes y\). And once you have the tensor product, every bilinear function on the original pair of spaces can be expressed uniquely as a linear function on the tensor product.

So, for example, the dot product can be written as

\[ x\cdot y = \phi(x\otimes y) \]

where \((x, y, z)\otimes(x', y', z')=(xx',xy',\ldots zz')\) and so the linear function is \(\phi(x_0, x_1,\ldots,x_8) = x_0+x_4+x_8\).

Similarly

\[ x\times y = \psi(x\otimes y) \]

where \(\psi(x_0, x_1, \ldots, x_8)=(x_5 - x_7, x_6 - x_2, x_1 - x_3)\).
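
To make this concrete, here's a small Haskell sketch. The names V3, outer, dot and cross are my own choices for this illustration, not anything used later in the post:

    -- A 3D vector, and the tensor product of a pair of them
    -- flattened into a 9-element list in the order (xx', xy', ..., zz').
    data V3 = V3 Double Double Double deriving Show

    outer :: V3 -> V3 -> [Double]
    outer (V3 x y z) (V3 x' y' z') =
        [x*x', x*y', x*z', y*x', y*y', y*z', z*x', z*y', z*z']

    -- The dot product is the linear function phi: it picks out
    -- entries 0, 4 and 8 of the tensor product.
    dot :: V3 -> V3 -> Double
    dot u v = let t = outer u v in t!!0 + t!!4 + t!!8

    -- The cross product is the linear function psi.
    cross :: V3 -> V3 -> V3
    cross u v = let t = outer u v
                in V3 (t!!5 - t!!7) (t!!6 - t!!2) (t!!1 - t!!3)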


Algebras

It's a confusing use of terminology, but the term "algebra (over \(k\))" is used specifically to mean a vector space \(A\) (over \(k\)) equipped with a bilinear product \(A\times A\rightarrow A\) that is compatible with the vector space structure. In addition, I'm assuming my algebras contain a multiplicative unit element; other people may call this a "unital algebra". I'll use the word "unital" when I want to stress that there is a unit.

An example is the algebra of complex numbers \(\mathbb{C}\) over \(\mathbb{R}\). It's a 2-dimensional vector space over \(\mathbb{R}\). We can, for example, scale complex numbers by elements of the base field. We also have properties like \((au)v = u(av)\) for \(a\in\mathbb{R}\) and \(u,v\in\mathbb{C}\). We can scale either argument of the complex product by a real and it makes no difference which we choose. See Wikipedia for all the properties an algebra must satisfy.

Vector spaces come with an addition operation and a zero but we're going to share the work out a little differently because our Num instance already has those. So our VectorSpace class is just going to have the scale operation:

> class VectorSpace k v where
>     scale :: k -> v -> v

> instance VectorSpace Double Double where
>     scale = (*)

You can think of the definition of Complex above as a container for the coordinates in a choice of basis. Because I use deriving Functor I can get the VectorSpace instance for all similar types for free:

> instance (Functor c, VectorSpace k a) => VectorSpace k (c a) where
>     scale k = fmap (scale k)

Because fmap composes through nested functors, scale descends recursively through arbitrarily nested structures like Complex (Complex Double).
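
For example, here is a quick check; the numbers are my own:

    scale (2 :: Double) (C (C 1 2) (C 3 4) :: Complex (Complex Double))
    -- == C (C 2.0 4.0) (C 6.0 8.0)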

And now we can concretely implement the bilinear tensor product operation in our choice of basis. It works by descending through the construction of \(x\) until it reaches its individual coordinates and then uses each one to scale \(y\). A special case of this is our 9-dimensional vector construction above: each batch of 3 coordinates is a scaling of one vector by a coordinate from the other.

> (⊗) :: (Functor c, VectorSpace k a) => c k -> a -> c a
> x ⊗ y = fmap (`scale` y) x

We're literally just building a table of all products of the coordinates of x with the coordinates of y, with scale handling the recursive descent.

Any bilinear function f :: U -> V -> W can now be implemented as f x y = phi (x ⊗ y) for a unique choice of phi.
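
For example, complex multiplication is itself bilinear over the reals, so it must factor through ⊗. Here's a sketch; the name phiMult is mine:

    -- For x = C a b and y = C c d the tensor product is
    --   x ⊗ y == C (C (a*c) (a*d)) (C (b*c) (b*d))
    -- and complex multiplication is recovered by a linear map:
    phiMult :: Num a => Complex (Complex a) -> Complex a
    phiMult (C (C p q) (C r s)) = C (p - s) (q + r)

    -- so that phiMult (x ⊗ y) == x * y for x, y :: Complex Double.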


Algebras too

But there's more, and this is the point of me writing this article. Algebras also have a tensor product defined on them. The underlying carrier space is the tensor product of the algebras considered as vector spaces. The product structure is defined by \((x\otimes y)(x'\otimes y')=(xx')\otimes(yy')\) and linear combinations thereof. But what's neat here is that we don't have to write any more code to implement this: our Num instance is already doing the work.
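
Before the general argument, here's a quick spot check of that property in Complex (Complex Double), whose Num instance comes for free from the instance above; the numbers are mine:

    let x  = C 1 2 :: Complex Double
        y  = C 3 4 :: Complex Double
        x' = C 5 6 :: Complex Double
        y' = C 7 8 :: Complex Double
    in (x ⊗ y) * (x' ⊗ y') == (x * x') ⊗ (y * y')
    -- evaluates to True: both sides are
    --   C (C 77.0 (-364.0)) (C (-176.0) 832.0)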

We need to check that our definition of Complex satisfies this property. In fact, I want to prove it more generally for any type like Complex that has a multiplication that looks like

    C a b * C c d = C (a * c - b * d) (a * d + b * c)

i.e. I'll assume we have a type F that is an instance of Num, with constructor F, and whose multiplication is constructed from a linear combination of terms of the form a * a'.

Something like:

    (F ... a ...) * (F ... a' ...) = F ... (... + a * a' + ...) ...

Note that I'm claiming

\[ \mathtt{F\ A} = \mathtt{F\ Double}\otimes\mathtt{A} \]

so I can suppose that a is in Double (or whatever we use to represent the reals).

Assuming * is such a product:

   (x ⊗ y) * (x' ⊗ y')
== fmap (`scale` y) x * fmap (`scale` y') x'
   -- definition of tensor
== fmap (`scale` y) (F ... a ...) * fmap (`scale` y') (F ... a' ...)
   -- stating our assumptions about the form of x and x'
== (F ... (scale a y) ...) * (F ... (scale a' y') ...)
   -- this is what derived fmap looks like
== F ... (... + scale a y * scale a' y' + ...) ...
   -- our assumption about the form that multiplication takes
== F ... (... + scale (a * a') (y * y') + ...) ...
   -- multiplication is bilinear all the way down
== fmap (`scale` (y * y')) (F ... (... + a * a' + ...) ...)
   -- same fact about fmap used above
== fmap (`scale` (y * y')) (x * x')
   -- again our assumption about how multiplication is implemented
== (x * x') ⊗ (y * y')
   -- definition of tensor again

Anyway, my motivation here is that quite a while back someone (on Mastodon, I think) pushed back on my claim that we have a tensor product, so I thought I'd give some more detail.

I could say more. The tensor product of algebras has the nice property that you can embed the original algebras in it in such a way that the two images commute with each other. In fact, you can define the tensor product to be the initial algebra with this property.
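
Those embeddings can be sketched in code, specialized to Complex Double for concreteness (the names embedLeft and embedRight are mine):

    -- Embed each factor using the unit of the other:
    embedLeft :: Complex Double -> Complex (Complex Double)
    embedLeft x = x ⊗ (1 :: Complex Double)

    embedRight :: Complex Double -> Complex (Complex Double)
    embedRight y = (1 :: Complex Double) ⊗ y

    -- The two images commute:
    --   embedLeft x * embedRight y == embedRight y * embedLeft x
    -- and in fact both sides equal x ⊗ y.

But this is too long already.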

Also, I used Haskell above but it carries over straightforwardly to other languages, even C++.

> main :: IO ()
> main = do
>   print "Bye!"

Sunday, April 12, 2026

A Simple Switch Makes Code Differentiable

Introduction

One game I like to play is decomposing code and algorithms as compositions of simpler pieces even when they already seem as simple as they could be.

One example is the observation that adjoint mode automatic differentiation isn't a separate algorithm from forward mode automatic differentiation but a composition of forward mode and transposition. I talked about this in an old paper of mine, Two Tricks for the Price of One, and it resurfaces more recently in Jax: You Only Linearize Once.

Another example is the REDUCE algorithm used to differentiate a class of stochastic processes despite making hard decisions when sampling from a distribution. It turns out this algorithm is nothing but ordinary everyday importance sampling, generalized to probabilities lying in a non-standard algebraic structure. I describe it in a blog article here and you can find out more about extending beyond the non-negative reals in a paper by Abramsky and Brandenberger. You sort of don't have to lift a finger to implement REDUCE - it just happens "for free" when you switch algebraic structure.

Here's a teeny tiny example of another "for free" method that just appears when you switch types.

The Good Old Vector Space Monad

First let me write the same "vector space" monad that I've written many times before around here:

> module Main where

> import qualified Data.Map.Strict as Map
> import Control.Applicative
> import Control.Monad

> newtype V s a = V { unV :: [(a, s)] }
>   deriving (Show)

> instance Functor (V s) where
>   fmap f (V xs) = V [ (f a, s) | (a, s) <- xs ]

> instance Num s => Applicative (V s) where
>   pure x = V [(x, 1)]
>   mf <*> mx = do { f <- mf; x <- mx; return (f x) }

> instance Num s => Monad (V s) where
>   (V xs) >>= f = V [ (b, s * t) |
>                      (a, s) <- xs, (b, t) <- unV (f a) ]

> instance Num s => Alternative (V s) where
>   empty = V []
>   (<|>) = add

> instance Num s => MonadPlus (V s) where
>   mzero = empty
>   mplus = (<|>)

> normalize :: (Ord a, Num s, Eq s) => V s a -> V s a
> normalize (V xs) = V $ Map.toList $
>                 Map.filter (/= 0) $ Map.fromListWith (+) xs

> scale :: Num s => s -> V s a -> V s a
> scale c (V xs) = V [ (a, c * s) | (a, s) <- xs ]

> add :: Num s => V s a -> V s a -> V s a
> add (V xs) (V ys) = V (xs ++ ys)
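
As a quick reminder of how this monad behaves, here's a tiny example; the names coin and xorFlips are mine. Bind multiplies weights along each path, and <|> concatenates (i.e. adds) distributions:

    coin :: V Double Bool
    coin = V [(False, 0.5), (True, 0.5)]

    -- The distribution of the XOR of two independent flips:
    -- each of the four paths carries weight 0.25.
    xorFlips :: V Double Bool
    xorFlips = do
        a <- coin
        b <- coin
        return (a /= b)

    -- normalize xorFlips == V [(False, 0.5), (True, 0.5)]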

Dictionaries

One of the challenges in machine learning is to replace standard algorithms that make hard decisions with methods that are soft and squishy so they're amenable to being differentiated. An example might be building dictionaries of key-value pairs. So let's write a line of code to insert into a dictionary:

> insert' :: [(k, v)] -> k -> v -> [(k, v)]
> insert' dict k v = (k, v) : dict

It's not very clever, but it does function as a dictionary builder. Let's write it in a more general way, using the fact that lists are instances of Applicative and Alternative:

> insert :: Alternative m => m (k, v) -> k -> v -> m (k, v)
> insert dict k v = pure (k, v) <|> dict

We can go ahead and insert some entries:

> main :: IO ()
> main = do
>     let dict1 = empty :: [(Int, Int)]
>     let dict2 = insert dict1 0 1
>     let dict3 = insert dict2 1 0
>     print dict3

DeltaNet

So how can we rewrite that to make it more amenable to differentiation? We don't need to! It's already general enough. We just switch to using V instead of []:

>     let dict4 = empty :: V Double (Int, Int)
>     let dict5 = insert dict4 0 1
>     let dict6 = insert dict5 1 0
>     print $ normalize dict6

What has happened is that we've replaced the dictionary update with

\[ D' = D + k\otimes v \]

Addition and tensor product are more amenable to differentiation than appending to a list. When using the vector space monad (or applicative), the (,) constructor plays a role more like a tensor product. Note that we're also no longer limited to inserting basis elements; we can use any suitably typed vectors:

>     let dict7 = dict6 <|>
>          (0.5 `scale` pure (0, 0) <|> 0.5 `scale` pure (1, 1))
>     print $ normalize dict7

This is the key ingredient in the update rule in DeltaNet, described in Parallelizing Linear Transformers with the Delta Rule over Sequence Length at the start of section 2.2.
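
Reading a value back out of such a dictionary is a linear operation too. Here's a sketch; lookupV is a name of my own choosing, and note that this version still makes a hard comparison on the key:

    -- Contract the dictionary against a key, keeping the
    -- weighted values stored under it.
    lookupV :: (Eq k, Num s) => k -> V s (k, v) -> V s v
    lookupV key dict = do
        (k, v) <- dict
        guard (k == key)
        return v

    -- e.g. normalize (lookupV 0 dict6) == V [(1, 1.0)]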

Can we do this to other algorithms?

These dictionaries aren't very smart. Could we do a similar trick with a binary tree, or a hierarchical binary tree of lists? One way of approaching the binary comparisons turns this into a kind of mixture of experts model, but I'm not sure it's a good place to have MoE. Maybe there's another way?
