Open renos opened 4 years ago
What is `index`?
If you want to vmap this function, you'll have to re-express it in terms of compilable functions. In particular, Python control flow based on the values within a JAX array cannot be lowered to XLA; see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Control-Flow for more information.
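To make the control-flow point concrete, here is a minimal sketch. The function `clip_negative` is a hypothetical example, not code from this issue. A Python `if x < 0:` on a traced value would fail under `vmap`, whereas the same branch written with `jnp.where` is computed elementwise and can be lowered:

```python
import jax
import jax.numpy as jnp


def clip_negative(x):
    # Hypothetical example function. Writing `if x < 0: return 0.0` here
    # would try to convert a traced array to a Python bool and raise an
    # error under vmap/jit. jnp.where selects between the two values
    # without Python-level branching.
    return jnp.where(x < 0, 0.0, x)


batched = jax.vmap(clip_negative)
result = batched(jnp.array([-1.0, 2.0, -3.0]))
```

For branches that are expensive to evaluate, `jax.lax.cond` is the usual alternative, though under `vmap` with a batched predicate both branches generally end up being computed anyway.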
@jakevdp Index: https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index.html
I understand I need to re-express it in terms of compilable functions... I guess a better question to ask is whether it's possible to implement functions like a depth-first search.
> I guess a better question to ask is whether it's possible to implement functions like a depth-first search.
XLA doesn't have data structures like sets or dynamic lists built in, only multi-dimensional arrays. You can probably write this logic in JAX-compilable functions, but replacing hash tables with linear search is not going to be efficient for large graphs.
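As a rough sketch of what an array-only version could look like: the following is not the `is_dag` from this issue, and it swaps depth-first search for a different technique, namely repeatedly peeling off nodes that have no incoming edges (in the spirit of Kahn's algorithm). It assumes every graph is given as a dense boolean adjacency matrix of the same fixed size `n`, with smaller graphs padded by isolated nodes. The name `is_dag_dense` is made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax


def is_dag_dense(adj):
    """Acyclicity check on a dense adjacency matrix (assumed layout).

    adj: (n, n) boolean array, adj[i, j] is True for an edge i -> j.
    """
    n = adj.shape[0]
    adj = adj.astype(bool)

    def body(_, alive):
        # A node has an incoming edge if some still-alive node points to it.
        has_incoming = jnp.any(adj & alive[:, None], axis=0)
        # Drop every alive node that has no incoming edge from alive nodes.
        return alive & has_incoming

    # A DAG loses at least one node per round, so n rounds always suffice.
    alive = lax.fori_loop(0, n, body, jnp.ones(n, dtype=bool))
    # Nodes on (or reachable only through) a cycle are never removed.
    return ~jnp.any(alive)


dag = [[0, 1, 0], [0, 0, 1], [0, 0, 0]]      # 0 -> 1 -> 2
cycle = [[0, 1, 0], [0, 0, 1], [1, 0, 0]]    # 0 -> 1 -> 2 -> 0
batch = jnp.array([dag, cycle], dtype=bool)
out = jax.vmap(is_dag_dense)(batch)
```

Because the loop always runs `n` rounds over an `n × n` matrix, the cost is roughly cubic in `n` per graph regardless of how sparse the graph is, which is the efficiency caveat mentioned above.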
Other options:
`jit` compile `is_dag`, but at least you could `vmap` it.
I would like to make the following function work with `vmap`:
Is it even possible to translate this function given its nonstatic nature over different graphs, and if not, is there any other way of implementing it?