Einsum Jacobian

You can write NNs, CNNs, and even transformers in terms of einsum and activation functions, so an autograd library built around just einsum and activations could handle modern architectures. That sounds funny, so I want to give it a try. The catch is that when you look at an einsum expression, it's not immediately clear how to differentiate it:

z = einsum('abc,dck->adbk', x, w)
...
dx = ???

The process

It's three steps:

  1. initialize an oversized jacobian of zeros

  2. put values from w into it

  3. if necessary, collapse/reshape it to match the original einsum

Step 1

In general, something like z = einsum(ptrn, x, w), with w held fixed, can be viewed as a function

E : \mathbb{R}^{\text{x.shape}}\rightarrow \mathbb{R}^{\text{z.shape}}

but instead we are going to view the output as being zero-embedded into the larger space indexed by both operands:

E : \mathbb{R}^{\text{x.shape}}\rightarrow \mathbb{R}^{\text{z.shape}} \hookrightarrow \mathbb{R}^{(\text{x.shape})\times(\text{w.shape})}

Thus our jacobian will be of shape (\text{x.shape} \times \text{w.shape}) \times \text{x.shape}, where

j_{ijklmn} = \frac{\partial z_{ijkl}}{\partial x_{mn}}

if x and w are both 2-dimensional matrices (for 3- or 4-d tensors you'd need more subscripts). This probably seems goofy, but it has some nice organizational properties.

Step 2

In this embedding no dimensions are summed, so every element of z must either be 0 (by the embedding) or one element of x multiplied by one element of w:

z_{ijkl} = \begin{cases} 0 \\ x_{ij}w_{kl} \end{cases}

Ok, where is z defined (not embedded to zero)? This is probably easiest to show by example.

  1. Consider 'ij,jk->ik'. In our embedding we get the tensor z_{ijkl} = \begin{cases} x_{ij}w_{kl} & \text{if } \ j = k \\ 0 & \text{else } \\ \end{cases}

  2. Consider 'ij,ij->ij'. In our embedding we get the tensor z_{ijkl} = \begin{cases} x_{ij}w_{kl} & \text{if }\ i = k, j = l \\ 0 & \text{else} \\ \end{cases}

  3. Consider 'bij,bk->bij'. In our embedding we get the tensor z_{ijklm} = \begin{cases} x_{ijk}w_{lm} & \text{if } \ i = l \\ 0 & \text{else} \\ \end{cases}

  4. Consider 'bij,kl->bijkl'. In our embedding we get the tensor z_{ijklm} = x_{ijk}w_{lm} for every index combination, since no letter is shared between the operands and nothing gets embedded as 0.

z is only defined (nonzero) where the indices shared by the two operands agree, i.e. on the diagonal of the axes that get multiplied together.
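To make the zero-embedding concrete, here is a minimal sketch for example 1 ('ij,jk->ik'). It uses NumPy rather than torch (an assumption for portability; the einsum patterns are the same), builds the unsummed outer product of x and w, zeroes every entry off the j = k diagonal, and checks that collapsing the embedding gives back the ordinary matrix product. The names `outer`, `mask`, and `z_embedded` are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3))  # indexed x_{ij}
w = rng.standard_normal((3, 4))  # indexed w_{kl}

# Outer product: every x_{ij} times every w_{kl}, nothing summed.
outer = np.einsum('ij,kl->ijkl', x, w)  # shape (2, 3, 3, 4)

# Example 1 keeps only entries with j == k; everything else is embedded as 0.
mask = np.eye(3)[None, :, :, None]  # shape (1, 3, 3, 1), broadcasts over i and l
z_embedded = outer * mask

# Collapsing the embedding (take the j == k diagonal, sum over j) recovers z = x @ w.
z = np.einsum('ijjl->il', z_embedded)
print(np.allclose(z, x @ w))  # True
```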

Looking at these, it's also clear that z_{ijkl} is independent of x_{mn} for (ij)\neq(mn), i.e.

\frac{\partial z_{ijkl}}{\partial x_{mn}} = \begin{cases} w_{kl} & \text{if } \ i = m, j = n \\ 0 & \text{else} \end{cases}

So in general our jacobian should be organized like this:

j_{ijklmn} = \begin{cases} w_{kl} & \text{if } \ i=m,j=n, \ z_{ijkl} \text{ is defined}\\ 0 & \text{else} \end{cases}
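As a worked instance, here is that formula specialized to example 1 ('ij,jk->ik'), where z_{ijkl} is defined only for j = k. Collapsing the embedding (setting k = j and summing over j) should give the familiar Jacobian of a matrix product:

```latex
j_{ijklmn} = \delta_{im}\,\delta_{jn}\,\delta_{jk}\,w_{kl}

\frac{\partial}{\partial x_{mn}} \sum_{j} z_{ijjl} = \sum_{j} j_{ijjlmn} = \sum_{j} \delta_{im}\,\delta_{jn}\,w_{jl} = \delta_{im}\,w_{nl}
```

This matches the direct calculation \partial (xw)_{il} / \partial x_{mn} = \delta_{im} w_{nl}.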

For examples 1-4, we populate the jacobian as follows:

# 1. input pattern 'ij,jk->ik'
einsum('ijjkij->ijk', jacobian)[:] = w
# 2. input pattern 'ij,ij->ij'
einsum('ijijij->ij', jacobian)[:] = w
# 3. input pattern 'bij,bk->bij'
einsum('bijbkbij->ijbk', jacobian)[:] = w
# 4. input pattern 'bij,kl->bijkl'
einsum('bijklbij->bijkl', jacobian)[:] = w

# General case: build the diagonal pattern from the einsum pattern
def jacobian_diagonal(ptrn):
    op1, op2 = ptrn.split('->')[0].split(',')
    # jacobian axes: x's labels, then w's labels, then x's labels again
    start = (op1 + op2) + op1
    # w's labels go last so that w broadcasts into the diagonal view
    end = "".join([c for c in op1 if c not in op2]) + op2
    return f"{start}->{end}"

We are broadcasting w into the diagonal view. The shape of w corresponds with op2, so op2's labels have to be the last dimensions of the diagonal (the end of the output pattern) in order for the broadcasting to work.
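A minimal sketch of step 2 for example 1, again assuming NumPy instead of torch: `np.einsum` with a single operand and no summed index returns a writeable view, so assigning into the diagonal writes straight into the big zero tensor. The loop at the end checks every entry against the formula for j_{ijklmn} given earlier.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3))
w = rng.standard_normal((3, 4))

# Step 1: zero tensor of shape (x.shape x w.shape) x x.shape.
jac = np.zeros((*x.shape, *w.shape, *x.shape))

# Step 2: 'ijjkij->ijk' is a view onto the entries with i = m, j = n and j = k.
diag = np.einsum('ijjkij->ijk', jac)
print(diag.shape)  # (2, 3, 4); w has shape (3, 4) and broadcasts over the leading i
diag[...] = w

# Check: j_{ijklmn} = w_{kl} if i = m, j = n and z_{ijkl} is defined (j = k), else 0.
for i, j, k, l, m, n in itertools.product(*map(range, jac.shape)):
    expected = w[k, l] if (i == m and j == n and j == k) else 0.0
    assert jac[i, j, k, l, m, n] == expected
```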

And that's it for step 2. Our jacobian has the correct values in the correct positions; it just needs to be collapsed/reshaped in accordance with the original einsum.

Step 3

Whatever axis summing or swapping happened in the original einsum now needs to happen in our jacobian. Remember our jacobian is \big(\text{x.shape} \times \text{w.shape} \big) \times \text{x.shape}. So we need to sum/swap the first \big(\text{x.shape} \times \text{w.shape} \big) dims the same way as the original while leaving the trailing \text{x.shape} dimensions alone. For examples 1-4 this would be

# 1. input pattern 'ij,jk->ik'
jacobian = einsum('ijjkab->ikab', jacobian)
# 2. input pattern 'ij,ij->ij'
jacobian = einsum('ijijab->ijab', jacobian)
# 3. input pattern 'bij,bk->bij'
jacobian = einsum('bijbkcde->bijcde', jacobian)
# 4. input pattern 'bij,kl->bijkl' (nothing summed or swapped, so this is a no-op)
jacobian = einsum('bijklabc->bijklabc', jacobian)

The subscript names don't matter as long as you sum the right stuff. For example 1 we could just as well use 'ijjkmn->ikmn'. Using ellipses for the trailing x.shape dimensions, we can write a simple function that generates this einsum pattern:

def organize_jacobian(ptrn):
    start, end = ptrn.replace(',', '').split('->')
    return f"{start}...->{end}..."

organize_jacobian('ij,jk->ik')
# returns ijjk...->ik...
organize_jacobian('bij,bik->bij')
# returns bijbik...->bij...
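A quick check of step 3 for example 1, sketched with NumPy (an assumption; the patterns carry over to torch unchanged). It collapses the filled jacobian three ways (two explicit subscript spellings and the ellipsis pattern) and compares each against the closed form dz_{ik}/dx_{ab} = \delta_{ia} w_{bk} for z = x @ w. The `organize_jacobian` helper mirrors the function above.

```python
import numpy as np

def organize_jacobian(ptrn):
    # Collapse the embedded (x.shape x w.shape) axes, keep the trailing x.shape axes.
    start, end = ptrn.replace(',', '').split('->')
    return f"{start}...->{end}..."

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3))
w = rng.standard_normal((3, 4))

# Steps 1 and 2 for example 1 ('ij,jk->ik').
jac = np.zeros((*x.shape, *w.shape, *x.shape))
np.einsum('ijjkij->ijk', jac)[...] = w

# Step 3: three spellings of the same collapse.
j_ab = np.einsum('ijjkab->ikab', jac)
j_mn = np.einsum('ijjkmn->ikmn', jac)
j_ellipsis = np.einsum(organize_jacobian('ij,jk->ik'), jac)  # 'ijjk...->ik...'

# Closed form for z = x @ w: dz_{ik}/dx_{ab} = delta_{ia} * w_{bk}.
expected = np.einsum('ia,bk->ikab', np.eye(2), w)
print(all(np.allclose(t, expected) for t in (j_ab, j_mn, j_ellipsis)))  # True
```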

Final Product

Let's write a little function that does the whole thing.

import torch
from torch import einsum
from torch.autograd.functional import jacobian

def jacobian_diag_ptrn(ptrn):
    op1, op2 = ptrn.split('->')[0].split(',')
    start = (op1 + op2) + op1
    end = "".join([c for c in op1 if c not in op2]) + op2
    return f"{start}->{end}"

def organize_jacobian_ptrn(ptrn):
    start, end = ptrn.replace(',', '').split('->')
    return f"{start}...->{end}..."

def get_ptrns(ptrn):
    return jacobian_diag_ptrn(ptrn), organize_jacobian_ptrn(ptrn)

def einsum_jacobian(ptrn, x, w):
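    # (x.shape x w.shape) x x.shape zeros, matching w's dtype and device
    j = torch.zeros(*x.shape, *w.shape, *x.shape, dtype=w.dtype, device=w.device)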
    diag_ptrn, org_ptrn = get_ptrns(ptrn)
    einsum(diag_ptrn, j)[:] = w
    return einsum(org_ptrn, j)

Now test it against PyTorch's autograd jacobian:

def sanity(ptrn, x, w):
    f = lambda x: einsum(ptrn, x, w)
    j = jacobian(f, x)
    manual_j = einsum_jacobian(ptrn, x, w)
    return torch.allclose(j, manual_j)

ptrn1 = 'ij,jk->ik' # mm
x1, w1 = torch.randn(2, 3), torch.randn(3, 4)

ptrn2 = 'bij,bjk->bik' # bmm
x2, w2 = torch.randn(5, 3, 8), torch.randn(5, 8, 9)

ptrn3 = 'bchwkt,fckt->bfhw' # convolution
x3, w3 = torch.randn(1, 3, 16, 16, 2, 2), torch.randn(2, 3, 2, 2)

ptrn4 = 'abcd,efd->' # ???
x4, w4 = torch.randn(1, 2, 3, 4), torch.randn(9, 10, 4)

stuff = [(ptrn1, x1, w1), (ptrn2, x2, w2), (ptrn3, x3, w3), (ptrn4, x4, w4)]
results = [sanity(*thing) for thing in stuff] 

results
# [True, True, True, True]
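For readers without torch, here is a NumPy port of the same pipeline, sketched under the assumption that NumPy's single-operand einsum returns a writeable view when nothing is summed. Instead of autograd it uses a different check: z is linear in x, so substituting an identity tensor for x yields the exact Jacobian. `einsum_jacobian_np` and `reference_jacobian` are hypothetical names, and the shapes are kept small on purpose.

```python
import numpy as np

def jacobian_diag_ptrn(ptrn):
    op1, op2 = ptrn.split('->')[0].split(',')
    start = (op1 + op2) + op1
    end = "".join(c for c in op1 if c not in op2) + op2
    return f"{start}->{end}"

def organize_jacobian_ptrn(ptrn):
    start, end = ptrn.replace(',', '').split('->')
    return f"{start}...->{end}..."

def einsum_jacobian_np(ptrn, x, w):
    j = np.zeros((*x.shape, *w.shape, *x.shape), dtype=w.dtype)
    # Single operand, nothing summed: NumPy returns a writeable view.
    np.einsum(jacobian_diag_ptrn(ptrn), j)[...] = w
    return np.einsum(organize_jacobian_ptrn(ptrn), j)

def reference_jacobian(ptrn, x, w):
    # z is linear in x, so feeding an identity tensor of shape x.shape + x.shape
    # in place of x gives dz/dx with the x axes trailing.
    ins, out = ptrn.split('->')
    op1, op2 = ins.split(',')
    eye = np.eye(x.size, dtype=x.dtype).reshape(*x.shape, *x.shape)
    return np.einsum(f"{op1}...,{op2}->{out}...", eye, w)

rng = np.random.default_rng(0)
cases = [
    ('ij,jk->ik', (2, 3), (3, 4)),
    ('bij,bjk->bik', (2, 3, 4), (2, 4, 5)),
    ('abc,dck->adbk', (2, 3, 4), (5, 4, 2)),
    ('abcd,efd->', (1, 2, 3, 4), (2, 3, 4)),
]
results = []
for ptrn, xs, ws in cases:
    x, w = rng.standard_normal(xs), rng.standard_normal(ws)
    results.append(np.allclose(einsum_jacobian_np(ptrn, x, w), reference_jacobian(ptrn, x, w)))
print(results)  # [True, True, True, True]
```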

And that's how you get the jacobian of einsum. It's not something you'd ever use, since it's so enormous that you run out of RAM with anything more than toy examples, but it's still kinda neat.
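Some back-of-the-envelope arithmetic on the size: the embedded jacobian has prod(x.shape)^2 · prod(w.shape) entries, versus prod(z.shape) · prod(x.shape) for the ordinary one. A small sketch for the 'bchwkt,fckt->bfhw' test case, assuming float32 (4 bytes per entry, torch's default dtype); `jacobian_sizes` is an illustrative helper.

```python
from math import prod

def jacobian_sizes(x_shape, w_shape, z_shape):
    # Embedded jacobian: (x.shape x w.shape) x x.shape entries.
    embedded = prod(x_shape) * prod(w_shape) * prod(x_shape)
    # Ordinary jacobian: z.shape x x.shape entries.
    ordinary = prod(z_shape) * prod(x_shape)
    return embedded, ordinary

# The 'bchwkt,fckt->bfhw' test case: x (1, 3, 16, 16, 2, 2), w (2, 3, 2, 2), z (1, 2, 16, 16).
embedded, ordinary = jacobian_sizes((1, 3, 16, 16, 2, 2), (2, 3, 2, 2), (1, 2, 16, 16))
print(embedded, ordinary)    # 226492416 1572864
print(embedded * 4 / 2**20)  # float32 size in MiB: 864.0
print(embedded // ordinary)  # 144
```

So even this 16x16 toy input needs roughly 864 MiB just for the zero tensor, about 144 times more than the ordinary jacobian.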