The JAX adjoint scatter trick
TL;DR: Use jax.vjp to get the scatter adjoint of a gather operation. Or any linear operation, for that matter, but the scatter-gather use may be less obvious at first sight.
Some types of simulation / linear operator code naturally follow the pattern:
1) Get some elements of a larger tensor, the ones related to this element / node / whatever
2) Do an operation on them
3) Write them back (summing the contributions from different elements / nodes / whatevers)
Writing part 3 efficiently is a bit tricky, as the indexing might get a tad hairy. And, as it turns out, writing part 3 explicitly is completely unnecessary.
For a simple example, let’s have each node be “between” two elements of a 1D tensor.
So, as code,
import jax
import jax.numpy as jnp

x = jnp.array([0, 0, 2, 1, 0, 0, 3, 0], jnp.float32)
gathered = gather(x)
operated = jax.vmap(op)(gathered)
result = scatter(operated)
with
def gather(v):
    return jnp.stack((v[:-1], v[1:]), axis=1)
def op(elts):  # Operate on a single node. Returns the values to write back.
    # Fictional example: we just take the mean
    # and propagate that to both elements.
    result = (elts[0] + elts[1]) / 2
    return jnp.array([result, result])
def scatter(result):
    # ... This is tricky to get right: the edge elements
    # get different contributions, and you need to concatenate zeros.
    #
    # Much more complex than gather().
    #
    # For a more complex gather (parameterized? Yikes!),
    # this would be even worse.
    return (
        jnp.concatenate((result[:, 0], jnp.zeros(1))) +
        jnp.concatenate((jnp.zeros(1), result[:, 1]))
    )
But here’s the good news: we don’t have to write the scatter function explicitly. Both gather and scatter are linear operations (representable as matrices), and they are adjoint to each other.
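To make the adjointness concrete, here is a quick sanity check (not from the original post): the 1D gather above can be written as an explicit matrix G, built by hand for the 8-element example, and then the scatter is just multiplication by G.T.

```python
import jax.numpy as jnp

n = 8
x = jnp.array([0, 0, 2, 1, 0, 0, 3, 0], jnp.float32)

# Build G so that G @ x == gather(x).reshape(-1):
# rows 2k and 2k+1 select x[k] and x[k+1].
G = jnp.zeros((2 * (n - 1), n))
for k in range(n - 1):
    G = G.at[2 * k, k].set(1.0)
    G = G.at[2 * k + 1, k + 1].set(1.0)

# Same as jnp.stack((x[:-1], x[1:]), axis=1):
gathered = (G @ x).reshape(-1, 2)

# The mean op from above, applied row-wise.
operated = gathered.mean(axis=1, keepdims=True) * jnp.ones((1, 2))

# The adjoint: scatter-with-summing is just the transpose.
result = G.T @ operated.reshape(-1)
```

The edge handling that made the hand-written scatter() awkward falls out of the transpose automatically: the first and last rows of x simply appear in fewer rows of G.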
All we need to do is:
gathered, scatter = jax.vjp(gather, x)
operated = jax.vmap(op)(gathered)
(result,) = scatter(operated)

(Note the unpacking: the function returned by jax.vjp yields a tuple of cotangents, one per primal argument.)
This correctly gives result as [0., 1., 2.5, 2., 0.5, 1.5, 3., 1.5] - the same as the explicit code before.
See how scatter was defined here inline by jax.vjp. Why does this work? A linear operation is its own derivative; a matrix is its own Jacobian, and vjp simply means “multiply this vector by the Jacobian from the output side” - exactly the right op for this occasion.
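A minimal illustration of that last point, using a small dense matrix (my own example, not from the post): for a linear map f(x) = A @ x, the vjp pullback of a cotangent v is exactly A.T @ v.

```python
import jax
import jax.numpy as jnp

# A small dense linear map: f(x) = A @ x.
A = jnp.array([[1., 2., 0.],
               [0., 3., 1.]])

def f(x):
    return A @ x

x = jnp.array([1., 1., 1.])
y, f_vjp = jax.vjp(f, x)

# Pulling back a cotangent v computes v @ J = A.T @ v:
# the adjoint (transpose) applied to v.
v = jnp.array([1., 10.])
(pullback,) = f_vjp(v)
```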
Now, we can make gather a lot more complex without any additional effort on the scatter side.
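For instance, here is a sketch of a gather driven by an arbitrary index array (the indices are made up for illustration); the matching scatter, including correct summing over repeated indices, still comes for free:

```python
import jax
import jax.numpy as jnp

# Each "element" reads an arbitrary set of entries of x;
# note that indices may repeat.
idx = jnp.array([[0, 3], [3, 5], [1, 1]])

def gather(v):
    return v[idx]

x = jnp.arange(6.0)
gathered, scatter = jax.vjp(gather, x)

# Any per-element op; here, just doubling.
operated = gathered * 2.0

# The adjoint sums all contributions landing on the same index.
(result,) = scatter(operated)
```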
Defining matrix adjoints via vjp is a powerful thing; for example, the matfree library uses this to great advantage - to specify a non-symmetric linear operator for a Krylov subspace method, you only need to provide the matrix-vector product one way.
What’s my specific use for this right now? FEM code. But that’s another story, for another time - or not, depending.