Einsum Jacobian
You can write NNs, CNNs, and even transformers in terms of einsum and activation functions. You could build an autograd library around einsum and activations, and it could handle modern architectures. That sounds fun, so I want to give it a try. When you look at these, it's not immediately clear how to differentiate them.
z = einsum('abc,dck->adbk', x, w)
...
dx = ???
The process
It's three steps:
initialize
put values from in
if necessary, collapse/reshape
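Jumping ahead, here's what those three steps look like for plain matrix multiplication (the patterns used are derived in the steps below; this relies on torch.einsum returning a writable view for diagonal-only patterns, which the tests at the end also depend on):

```python
import torch

x = torch.randn(2, 3)
w = torch.randn(3, 4)

# 1. initialize: zeros of shape x.shape + w.shape + x.shape
j = torch.zeros(2, 3, 3, 4, 2, 3)
# 2. put values in: write w along the matching diagonal
torch.einsum('ijjkij->ijk', j)[:] = w
# 3. collapse/reshape the way the original einsum 'ij,jk->ik' did
j = torch.einsum('ijjkab->ikab', j)
print(j.shape)  # (2, 4, 2, 3), i.e. z.shape + x.shape
```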
Step 1
In general, something like z = einsum(ptrn, x, w) (with w held fixed) can be viewed as a function
E : \mathbb{R}^{\text{x.shape}}\rightarrow \mathbb{R}^{\text{z.shape}}
but instead we are going to view the output as being zero-embedded into
E : \mathbb{R}^{\text{x.shape}}\rightarrow \mathbb{R}^{\text{z.shape}} \hookrightarrow \mathbb{R}^{(\text{x.shape})\times(\text{w.shape})}
Thus our jacobian will be of shape (\text{x.shape} \times \text{w.shape}) \times \text{x.shape} where
j_{ijklmn} = \frac{\partial z_{ijkl}}{\partial x_{mn}} (if X and W are both 2-dimensional matrices; for 3- or 4-d you'd need more subscripts). This probably seems goofy, but it has some nice organizational properties.
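Concretely, for matmul-sized operands (shapes chosen arbitrarily here), the usual jacobian and the zero-embedded one look like:

```python
import torch
from torch.autograd.functional import jacobian

x = torch.randn(2, 3)   # x.shape
w = torch.randn(3, 4)   # w.shape

# the usual view: R^{x.shape} -> R^{z.shape}, so the jacobian is z.shape + x.shape
j = jacobian(lambda x: torch.einsum('ij,jk->ik', x, w), x)
print(j.shape)  # torch.Size([2, 4, 2, 3])

# the zero-embedded view: a jacobian of shape (x.shape + w.shape) + x.shape
j_embedded = torch.zeros(*x.shape, *w.shape, *x.shape)
print(j_embedded.shape)  # torch.Size([2, 3, 3, 4, 2, 3])
```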
Step 2
In this embedding no dimensions are summed. So every element of Z must either be 0, by embedding, or something from X multiplied by something from W.
z_{ijkl} = \begin{cases} 0 \\ x_{ij}w_{kl} \end{cases}
Ok, where is Z defined (not embedded to zero)? This is probably easiest to show by example.
Consider einsum('ij,jk->ik', x, w); in our embedding we get the tensor z_{ijkl} = \begin{cases} x_{ij}w_{kl} & \text{if } \ j = k \\ 0 & \text{else } \\ \end{cases}
Consider einsum('ij,ij->ij', x, w); in our embedding we get the tensor z_{ijkl} = \begin{cases} x_{ij}w_{kl} & \text{if }\ i = k, j = l \\ 0 & \text{else} \\ \end{cases}
Consider einsum('bij,bk->bij', x, w); in our embedding we get the tensor z_{ijklm} = \begin{cases} x_{ijk}w_{lm} & \text{if } \ i = l \\ 0 & \text{else} \\ \end{cases}
Consider einsum('bij,kl->bijkl', x, w); in our embedding we get the tensor z_{ijklm} = x_{ijk}w_{lm}, defined everywhere (nothing is embedded to zero).
Z is only defined along the axes that get multiplied together.
Looking at these, it's also clear that z_{ijkl} is independent of x_{mn} for (ij)\neq(mn). i.e. \frac{\partial z_{ijkl}}{\partial x_{mn}} = \begin{cases} w_{kl} & \text{if } \ i = m, j = n \\ 0 & \text{else} \end{cases}
So in general our jacobian should be organized like this j_{ijklmn} = \begin{cases} w_{kl} & \text{if } \ i=m,j=n, \ z_{ijkl} \text{ is defined}\\ 0 & \text{else} \end{cases}
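A quick numerical check of this rule for the matmul pattern 'ij,jk->ik' (building the zero-embedded output explicitly and differentiating it with autograd; shapes are arbitrary):

```python
import torch
from torch.autograd.functional import jacobian

x = torch.randn(2, 3)
w = torch.randn(3, 4)

def embedded(x):
    # zero-embedded z for 'ij,jk->ik': z[i,j,k,l] = x[i,j] * w[k,l] when j == k, else 0
    outer = torch.einsum('ij,kl->ijkl', x, w)  # every x-entry times every w-entry
    mask = torch.eye(3).reshape(1, 3, 3, 1)    # keep only the j == k slots
    return outer * mask

j_auto = jacobian(embedded, x)                 # shape (2, 3, 3, 4, 2, 3)

# the rule: j[i,j,k,l,m,n] = w[k,l] when i == m, j == n, and z_{ijkl} is defined (j == k)
j_rule = torch.zeros(2, 3, 3, 4, 2, 3)
for i in range(2):
    for j in range(3):
        for l in range(4):
            j_rule[i, j, j, l, i, j] = w[j, l]

print(torch.allclose(j_auto, j_rule))  # True
```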
For examples 1-4, we populate the jacobian as follows
# 1. input pattern 'ij,jk->ik'
einsum('ijjkij->ijk', jacobian)[:] = w
# 2. input pattern 'ij,ij->ij'
einsum('ijijij->ij', jacobian)[:] = w
# 3. input pattern 'bij,bk->bij'
einsum('bijbkbij->ijbk', jacobian)[:] = w
# 4. input pattern 'bij,kl->bijkl'
einsum('bijklbij->bijkl', jacobian)[:] = w
We can write a simple function that generates these diagonal patterns:
def jacobian_diagonal(ptrn):
    op1, op2 = ptrn.split('->')[0].split(',')
    start = (op1 + op2) + op1
    end = "".join([c for c in op1 if c not in op2]) + op2
    return f"{start}->{end}"
We are broadcasting w into the diagonal. The shape of w corresponds with op2, so op2 has to be the last dimensions of the diagonal output in order for the broadcasting to work.
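For reference, running the helper on the four example patterns reproduces exactly the diagonals used above:

```python
def jacobian_diagonal(ptrn):
    op1, op2 = ptrn.split('->')[0].split(',')
    start = (op1 + op2) + op1
    end = "".join([c for c in op1 if c not in op2]) + op2
    return f"{start}->{end}"

print(jacobian_diagonal('ij,jk->ik'))      # ijjkij->ijk
print(jacobian_diagonal('ij,ij->ij'))      # ijijij->ij
print(jacobian_diagonal('bij,bk->bij'))    # bijbkbij->ijbk
print(jacobian_diagonal('bij,kl->bijkl'))  # bijklbij->bijkl
```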
And that's it for step 2. Our jacobian has the correct values in the
correct positions. It just needs to be reshaped in accordance with the
original einsum.
Step 3
Whatever axis summing or swapping happened in the original einsum now needs to happen in our jacobian. Remember our jacobian is \big(\text{x.shape} \times \text{w.shape} \big) \times \text{x.shape}. So we need to sum/swap the first \big(\text{x.shape} \times \text{w.shape} \big) dims the same way as the original while leaving the trailing \text{x.shape} dimensions alone. For examples 1-4 this would be
# 1. input pattern 'ij,jk->ik'
jacobian = einsum('ijjkab->ikab', jacobian)
# 2. input pattern 'ij,ij->ij'
jacobian = einsum('ijijab->ijab', jacobian)
# 3. input pattern 'bij,bk->bij'
jacobian = einsum('bijbkcde->bijcde', jacobian)
# 4. input pattern 'bij,kl->bijkl'
jacobian = einsum('bijklcde->bijklcde', jacobian)
The subscript names don't matter as long as you sum the right stuff. For 1 we could just as well do einsum('pqqrst->prst', jacobian). We can write a simple function that generates this einsum pattern by using ellipses.
def organize_jacobian(ptrn):
    start, end = ptrn.replace(',', '').split('->')
    return f"{start}...->{end}..."
organize_jacobian('ij,jk->ik')
# returns ijjk...->ik...
organize_jacobian('ij,ij->')
# returns ijij...->...
Final Product
Let's write a little function that does the whole thing.
import torch
from torch import einsum
from torch.autograd.functional import jacobian
def jacobian_diag_ptrn(ptrn):
    op1, op2 = ptrn.split('->')[0].split(',')
    start = (op1 + op2) + op1
    end = "".join([c for c in op1 if c not in op2]) + op2
    return f"{start}->{end}"
def organize_jacobian_ptrn(ptrn):
    start, end = ptrn.replace(',', '').split('->')
    return f"{start}...->{end}..."
def get_ptrns(ptrn):
    return jacobian_diag_ptrn(ptrn), organize_jacobian_ptrn(ptrn)
def einsum_jacobian(ptrn, x, w):
    j = torch.zeros(*x.shape, *w.shape, *x.shape)
    diag_ptrn, org_ptrn = get_ptrns(ptrn)
    einsum(diag_ptrn, j)[:] = w
    return einsum(org_ptrn, j)
Now test:
def sanity(ptrn, x, w):
    f = lambda x: einsum(ptrn, x, w)
    j = jacobian(f, x)
    manual_j = einsum_jacobian(ptrn, x, w)
    return torch.allclose(j, manual_j)
ptrn1 = 'ij,jk->ik' # mm
x1, w1 = torch.randn(2, 3), torch.randn(3, 4)
ptrn2 = 'bij,bjk->bik' # bmm
x2, w2 = torch.randn(5, 3, 8), torch.randn(5, 8, 9)
ptrn3 = 'bchwkt,fckt->bfhw' # convolution
x3, w3 = torch.randn(1, 3, 16, 16, 2, 2), torch.randn(2, 3, 2, 2)
ptrn4 = 'abcd,efd->' # ???
x4, w4 = torch.randn(1, 2, 3, 4), torch.randn(9, 10, 4)
stuff = [(ptrn1, x1, w1), (ptrn2, x2, w2), (ptrn3, x3, w3), (ptrn4, x4, w4)]
results = [sanity(*thing) for thing in stuff]
results
# [True, True, True, True]
And that's how you get the jacobian of einsum. It's not something you'd ever use, since it's so enormous that you run out of RAM on anything more than toy examples, but it's still kinda neat.
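For a sense of scale (sizes here are hypothetical): even one modest 512-unit linear layer at batch size 64 makes the pre-collapse buffer astronomically large.

```python
import math

# hypothetical linear layer: x is (batch, in), w is (in, out)
x_shape, w_shape = (64, 512), (512, 512)

# pre-collapse jacobian buffer: x.shape + w.shape + x.shape
numel = math.prod(x_shape) * math.prod(w_shape) * math.prod(x_shape)
print(numel)             # 281474976710656 elements
print(numel * 4 / 1e12)  # ~1126 terabytes at float32
```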