
[Python] Mapping Over Arrays of Functions in JAX

Discussion in 'Python' started by Stack, September 13, 2024.

  1. Stack

    Stack Participating Member

    What is the most performant, idiomatic way of mapping over arrays of functions in JAX?

    Context: This GitHub issue shows a way to apply vmap to several functions using lax.switch. The example is reproduced below:

    from jax import lax, vmap
    import jax.numpy as jnp

    def func1(x):
        return 2 * x

    def func2(x):
        return -2 * x

    def func3(x):
        return 0 * x

    functions = [func1, func2, func3]
    index = jnp.arange(len(functions))  # one branch index per row of x
    x = jnp.ones((3, 5))

    # Row i of x is passed through functions[index[i]].
    vmap_functions = vmap(lambda i, xi: lax.switch(i, functions, xi))
    vmap_functions(index, x)
    # DeviceArray([[ 2.,  2.,  2.,  2.,  2.],
    #              [-2., -2., -2., -2., -2.],
    #              [ 0.,  0.,  0.,  0.,  0.]], dtype=float32)


    My specific questions are:

    • Is this (currently) the most idiomatic way of mapping over arrays of functions in JAX?
    • What performance penalties, if any, does this method incur? (This includes both runtime and compile-time performance.)
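    One way to examine both costs empirically is to print the traced program with jax.make_jaxpr and to time ahead-of-time compilation and execution separately. This is only a sketch: `batched_switch` is a hypothetical wrapper name, and single wall-clock timings on a toy (3, 5) input say little about larger workloads or other backends.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax


def func1(x):
    return 2 * x


def func2(x):
    return -2 * x


def func3(x):
    return 0 * x


functions = [func1, func2, func3]


def batched_switch(index, x):
    # Hypothetical wrapper: each row of x picks its own branch via index.
    return jax.vmap(lambda i, xi: lax.switch(i, functions, xi))(index, x)


index = jnp.arange(len(functions))
x = jnp.ones((3, 5))

# Inspect the traced program to see which operations the vmapped switch becomes.
print(jax.make_jaxpr(batched_switch)(index, x))

# Compile-time cost: lower and compile ahead of time, without executing.
t0 = time.perf_counter()
compiled = jax.jit(batched_switch).lower(index, x).compile()
compile_seconds = time.perf_counter() - t0

# Runtime cost: warm up once, then time a call, blocking until the result is ready.
compiled(index, x).block_until_ready()
t0 = time.perf_counter()
result = compiled(index, x).block_until_ready()
run_seconds = time.perf_counter() - t0

print(f"compile: {compile_seconds:.4f} s, run: {run_seconds:.6f} s")
```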

