Why you should write Jax functions without broadcasting
TL;DR: Obscure style tip for Jax. Relevant only for people who use Jax and use broadcasting regularly. (Note: not an April Fools' joke, despite the publication date.)
I've been using Jax for a long time in personal (and some work) projects (see the footnote on Jax and torch). Recently, I realized that I had been using Jax in a suboptimal way: I was using it too much like numpy.
Old way: make everything broadcastable
In numpy (and similar array libraries) you want to code your functions so that broadcasting works. That means writing code that only touches the last few dimensions and allows any number of leading dimensions. The reason is performance: if you write a function for fixed, small dimensions and try to broadcast it after the fact with e.g. numpy.vectorize, you lose most of the performance benefits of numpy, because what you just wrote is a python loop.
This makes it easy to end up with certain classes of bugs:
- Forgetting to add "…" to a slice will bite you in unexpected ways much later.
- Same with specifying axes to aggregate over: they need to be negative, not positive (see the small sketch after this list).
- Also, forget about using the matrix multiplication operator @: with batched matrices and vectors it will contract the wrong indices.
- Every einsum invocation must do extra footwork with the ellipsis.
- &c &c
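To make the second bullet concrete, here is a tiny numpy sketch (the norms helper and its shapes are invented for illustration, not from the original codebase): aggregating over a positive axis works for the shape you tested against, and silently computes the wrong thing once a leading batch dimension appears.

import numpy as np

def norms(v):
    # Intended: the length of each vector along the LAST axis.
    return np.sqrt((v ** 2).sum(axis=1))   # BUG: should be axis=-1

v = np.ones((10, 3))
print(norms(v).shape)        # (10,)  -- looks correct for 2D input
batched = np.ones((2, 10, 3))
print(norms(batched).shape)  # (2, 3) -- silently wrong, and no error is raised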
Let’s look at the following example:
def apply_rigid(m, v):
    """Apply a rigid 4x4 transform matrix to a 3-vector.

    Args:
        m -- a (..., 4, 4) rigid 4x4 transform matrix
        v -- a (..., 3) vector
    """
    return np.einsum("...ij, ...j -> ...i", m[..., 0:3, 0:3], v) + m[..., 0:3, 3]
This is a slightly edited version from a robotics codebase. Note how much even this trivial function has to do to successfully broadcast in both m and v.
If we do the same thing but with numpy.vectorize, we will have cleaner code but FAR less performance:
def apply_rigid_single(m, v):
    """Apply a rigid 4x4 transform matrix to a 3-vector.

    Args:
        m -- a (4, 4) rigid 4x4 transform matrix
        v -- a (3,) vector
    """
    return m[0:3, 0:3] @ v + m[0:3, 3]

apply_rigid_single_np_vectorize = np.vectorize(
    apply_rigid_single,
    signature="(4,4),(3)->(3)"
)
Let’s measure it:
# Multiply a single vector by 100 000 matrices
m = (np.eye(4, 4) + np.zeros((100_000, 1, 1))).astype(np.float32)
v = np.ones(3, np.float32)
%timeit apply_rigid(m, v)
# --> 854 μs ± 3.47 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%timeit apply_rigid_single_np_vectorize(m, v)
# --> 136 ms ± 1.64 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
The np.vectorize version is over 100 times slower! (Numbers on an M4 MacBook Pro; YMMV.)
New way: scalar(ish) functions
Jax looks and feels a lot like numpy, and this is really good (familiar API). Don't let that fool you, though: the performance tradeoffs are completely different. If you take a Jax function written for scalars and 1D vectors and wrap it in jax.numpy.vectorize, it will run at the same speed as a hand-broadcasted function.
import jax
import jax.numpy as jnp

def jax_apply_rigid(m, v):
    """Apply a rigid 4x4 transform matrix to a 3-vector.

    Args:
        m -- a (..., 4, 4) rigid 4x4 transform matrix
        v -- a (..., 3) vector
    """
    return jnp.einsum("...ij, ...j -> ...i", m[..., 0:3, 0:3], v) + m[..., 0:3, 3]
def jax_apply_rigid_single(m, v):
    """Apply a rigid 4x4 transform matrix to a 3-vector.

    Args:
        m -- a (4, 4) rigid 4x4 transform matrix
        v -- a (3,) vector
    """
    return m[0:3, 0:3] @ v + m[0:3, 3]

jax_apply_rigid_single_np_vectorize = jnp.vectorize(
    jax_apply_rigid_single,
    signature="(4,4),(3)->(3)"
)
jax_m = (jnp.eye(4, 4) + jnp.zeros((100_000, 1, 1))).astype(np.float32)
jax_v = jnp.ones(3, np.float32)
%timeit jax_apply_rigid(jax_m, jax_v)
# --> 814 μs ± 27.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%timeit jax_apply_rigid_single_np_vectorize(jax_m, jax_v)
# --> 917 μs ± 13.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
The two versions have approximately the same performance, and this didn't even need JIT compilation (see the JIT footnote)! Unlike numpy, where the inner function is physically called once per element, Jax interprets the vectorized call using tracers even in non-compiled mode and executes the python function only once.
Why is this important?
- For simplicity: not having to think about the extra dimensions frees up cognitive load while coding. Contrast
jnp.einsum("...ij, ...j -> ...i", m[..., 0:3, 0:3], v) + m[..., 0:3, 3]
with
m[0:3, 0:3] @ v + m[0:3, 3]
Broadcasting doesn't let us use the @ operator, among other things. The latter code is so much easier to write and read.
- For strict and robust checks of input parameters. We can add asserts to check the input shapes:
assert m.shape == (4, 4), m.shape
assert v.shape == (3,), v.shape
This lets us catch mistakes like the following:
# This is the important example to avoid for robustness:
# the @ operator inside the function would
# silently do the wrong thing.
jax_apply_rigid_single(jnp.eye(4), jnp.ones((4, 3)))
# --> assertion error
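For completeness, here is a small sketch of the single-sample function with those asserts inlined (the post shows only the assert lines themselves; placing them at the top of the function body is my assumption):

def jax_apply_rigid_single(m, v):
    """Apply a rigid 4x4 transform matrix to a 3-vector, strict shapes only."""
    assert m.shape == (4, 4), m.shape   # refuse anything that is not exactly 4x4
    assert v.shape == (3,), v.shape     # refuse batched or mis-shaped vectors
    return m[0:3, 0:3] @ v + m[0:3, 3]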
You can also use jaxtyping to both document the dimensions and automatically typecheck them at runtime when the function is called. (Note that typeguard must be pinned to version 2.13.3.)

from jaxtyping import Array, Float
from typeguard import typechecked

@typechecked
def jax_apply_rigid_single(
    m: Float[Array, "4 4"],
    v: Float[Array, "3"],
) -> Float[Array, "3"]:
    """Apply a rigid 4x4 transform matrix to a 3-vector with type checking."""
    return m[0:3, 0:3] @ v + m[0:3, 3]

# This will pass type checking
result = jax_apply_rigid_single(jnp.eye(4), jnp.ones(3))

# This is the important example to avoid for robustness:
# the @ operator would silently do the wrong thing.
jax_apply_rigid_single(jnp.eye(4), jnp.ones((4, 3)))
# --> TypeError: type of argument "v" must be jaxtyping.Float[Array, '3'];
#     got jaxlib.xla_extension.ArrayImpl instead
(the error message is not quite as good as one might hope for but still…)
After defining the primitive function like this, we can use jax.vmap or jnp.vectorize (see the pytrees footnote) to enable full broadcasting for the function, as sketched below.
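As a sketch, the wrapping can happen entirely at the call site (reusing jax_m and jax_v from the benchmark above; that reuse is my assumption about the intended usage):

# Option 1: jax.vmap over an explicit batch axis of m, with v unbatched.
batched = jax.vmap(jax_apply_rigid_single, in_axes=(0, None))
out = batched(jax_m, jax_v)                 # shape (100_000, 3)

# Option 2: jnp.vectorize for full numpy-style broadcasting; this is the same
# pattern as the jax_apply_rigid_single_np_vectorize wrapper defined earlier.
vectorized = jnp.vectorize(jax_apply_rigid_single, signature="(4,4),(3)->(3)")
out = vectorized(jax_m, jax_v)              # shape (100_000, 3)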
That brings us to a further subtle point: quite often, the first instinct is to reach for jnp.vectorize to make the function fully broadcasting. But wait (like R1). The above style tip also applies here, recursively: depending on the function, you might want to not vectorize it at all, and use vmap or vectorize in the client code instead.
Maybe it’s worth providing a non-vectorized (strict) version and a vectorized version separately; I haven’t yet converged on a specific answer.
Why I find this interesting
I think it's fascinating how the underlying reality impinges on higher-level practices (leaky abstractions all the way down). There is some analogy to designing hardware in VHDL: even though it looks like software, you are actually building physical gates, and that changes the tradeoffs. Instead of optimizing the common path, you have to shorten the most difficult (longest) path between flip-flops, since that is what determines your achievable clock frequency.
Please drop me a line if you have other interesting examples :)
Footnotes
Jax and torch
For most customer projects, things seem to converge on pytorch due to the pretrained model ecosystem. In my personal projects, I'm usually doing something more intricate with the computation, so Jax's flexible autograd and JIT fit the bill very nicely.
JIT
With JIT, it’s even faster and there is no difference between the vectorized and non-vectorized versions:
jit_jax_apply_rigid = jax.jit(jax_apply_rigid)
jit_jax_apply_rigid_single_np_vectorize = jax.jit(
jax_apply_rigid_single_np_vectorize
)
%timeit jit_jax_apply_rigid(jax_m, jax_v)
# --> 107 μs ± 2.06 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit jit_jax_apply_rigid_single_np_vectorize(jax_m, jax_v)
# --> 108 μs ± 2.3 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
That’s … fast. And this is still on the CPU.
Pytrees
It's good to be aware that jnp.vectorize does not support pytrees. A version that allowed a good SoA (structure of arrays) signature would be really nice to have. For pytrees, jax.vmap is the only game in town.
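A minimal sketch of what that looks like with jax.vmap (the score function and its dict layout are hypothetical, just to show a structure-of-arrays batch):

def score(sample):
    # sample is a pytree (here a dict) holding the leaves of ONE example
    return sample["w"] @ sample["x"] + sample["b"]

batch = {
    "w": jnp.ones((100, 3, 3)),
    "x": jnp.ones((100, 3)),
    "b": jnp.zeros((100, 3)),
}
out = jax.vmap(score)(batch)   # vmap maps axis 0 of every leaf -> (100, 3)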