What is the most performant, idiomatic way of mapping over arrays of functions in JAX?

Context: This GitHub issue shows a way to apply vmap to several functions using lax.switch. The example is reproduced below:

    from jax import lax, vmap
    import jax.numpy as jnp

    def func1(x):
        return 2 * x

    def func2(x):
        return -2 * x

    def func3(x):
        return 0 * x

    functions = [func1, func2, func3]
    index = jnp.arange(len(functions))
    x = jnp.ones((3, 5))

    vmap_functions = vmap(lambda i, x: lax.switch(i, functions, x))
    vmap_functions(index, x)
    # DeviceArray([[ 2.,  2.,  2.,  2.,  2.],
    #              [-2., -2., -2., -2., -2.],
    #              [ 0.,  0.,  0.,  0.,  0.]], dtype=float32)

My specific questions are:

1. Is this (currently) the most idiomatic way of mapping over arrays of functions in JAX?
2. What performance penalties, if any, does this method incur? (This refers to both runtime and compile-time performance.)
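One way to probe the performance question is to inspect what the vmapped switch traces to and compare it against a plain unrolled loop. The sketch below is only an illustration under stated assumptions: `loop_functions` is a hypothetical helper that does not appear in the GitHub issue, and the expectation that a batched index turns the switch into a select (so every branch runs for every row) reflects my understanding of JAX's batching behavior, so it is worth confirming against the printed jaxpr for your JAX version.

```python
import jax
import jax.numpy as jnp
from jax import lax, vmap


def func1(x):
    return 2 * x


def func2(x):
    return -2 * x


def func3(x):
    return 0 * x


functions = [func1, func2, func3]
index = jnp.arange(len(functions))
x = jnp.ones((3, 5))

# The approach from the question: one switch per row, batched with vmap.
vmap_functions = vmap(lambda i, x: lax.switch(i, functions, x))

# Print the traced program to see how the batched switch is represented.
# With a batched index, expect all branches to be computed and then selected.
print(jax.make_jaxpr(vmap_functions)(index, x))


# Hypothetical alternative: unroll over the functions at trace time, so each
# row is passed through exactly one function. Compile time grows with the
# number of functions because the loop is unrolled.
@jax.jit
def loop_functions(x):
    rows = [functions[i](x[i]) for i in range(len(functions))]
    return jnp.stack(rows)


switched = vmap_functions(index, x)
looped = loop_functions(x)
```

Both versions should agree numerically. Which one is faster likely depends on how many functions there are and how expensive each branch is, so timing both on representative inputs (after a warm-up call to exclude compilation) is the safest way to decide.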