jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Making function vmap'able #4157

Open · renos opened this issue 4 years ago

renos commented 4 years ago

I would like to make the following function work with vmap:

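import jax.numpy as jnp

def is_dag(graph):
  # Kahn's algorithm: graph[u, v] != 0 means an edge u -> v.
  L = list()  # nodes in topological order
  S = set()   # nodes with no remaining incoming edges
  for i in range(graph.shape[0]):
    if jnp.sum(graph[:, i]) == 0:
      S.add(i)
  while len(S) != 0:
    n = S.pop()
    L.append(n)

    for mpos in range(graph.shape[0]):
      if graph[n, mpos]:
        # Written at the time as index_update(graph, index[n, mpos], 0)
        # from jax.ops; graph.at[...].set() is the current equivalent.
        graph = graph.at[n, mpos].set(0)
        if jnp.sum(graph[:, mpos]) == 0:
          S.add(mpos)
  return jnp.sum(graph) == 0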

Is it even possible to translate this function, given that its control flow depends on the values in each graph? If not, is there another way to implement it?

jakevdp commented 4 years ago

What is index?

If you want to vmap this function, you'll have to re-express it in terms of compilable functions. In particular, Python control flow based on the values within a JAX array cannot be lowered to XLA; see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Control-Flow for more information.
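A minimal sketch of what that means in practice, assuming a recent JAX version: a Python `if` on a batched value fails under `jax.vmap`, while the same test written as an array expression works. The function names here are hypothetical; `source_mask` is one possible vectorized replacement for the first loop of `is_dag` (finding nodes whose column sums to zero).

```python
import jax
import jax.numpy as jnp

def has_source_python(graph):
  # Python `if` on a traced value: needs a concrete bool, so vmap/jit reject it.
  if jnp.sum(graph[:, 0]) == 0:
    return True
  return False

def source_mask(graph):
  # Array-level version: node j has no incoming edges when column j sums to 0.
  return jnp.sum(graph, axis=0) == 0

graphs = jnp.array([
    [[0, 1], [0, 0]],  # edge 0 -> 1
    [[0, 1], [1, 0]],  # edges 0 -> 1 and 1 -> 0 (a cycle)
])

masks = jax.vmap(source_mask)(graphs)  # one boolean mask per graph

vmap_failed = False
try:
  jax.vmap(has_source_python)(graphs)
except jax.errors.ConcretizationTypeError:
  vmap_failed = True
```

The vectorized version computes the mask for every node at once instead of branching per node, which is the kind of rewrite the rest of the function would also need.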

renos commented 4 years ago

@jakevdp Index: https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index.html

I understand I need to re-express it in terms of compilable functions... I guess a better question to ask is whether it's possible to implement functions like a depth-first search.
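One way to answer the DAG question without any traversal is sketched below. Note that this is not a depth-first search: it swaps in a transitive-closure computation by repeated boolean matrix squaring, which uses only fixed-shape array operations and therefore works with `jax.vmap`. A graph has a cycle exactly when some node can reach itself. The helper names are hypothetical, and the dense approach costs roughly O(n³ log n) work per graph, so it only suits small graphs.

```python
import jax
import jax.numpy as jnp
from jax import lax

def reachability(graph):
  # R[i, j] is True iff there is a path of length >= 1 from i to j,
  # assuming graph[u, v] != 0 means an edge u -> v.
  n = graph.shape[0]  # static under jit/vmap, so plain Python math is fine
  r = graph.astype(bool)
  # After k squarings r covers path lengths 1..2**k; we need 2**k >= n.
  steps = max(1, (n - 1).bit_length())

  def square(_, r):
    ri = r.astype(jnp.int32)
    # r | (r @ r): extend every known path by another known path.
    return r | ((ri @ ri) > 0)

  return lax.fori_loop(0, steps, square, r)

def is_dag_by_reachability(graph):
  # Acyclic iff no node can reach itself.
  return ~jnp.any(jnp.diagonal(reachability(graph)))

dag = jnp.array([[0, 1, 1], [0, 0, 1], [0, 0, 0]])    # 0->1, 0->2, 1->2
cycle = jnp.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]])  # 0->1->2->0
results = jax.vmap(is_dag_by_reachability)(jnp.stack([dag, cycle]))
```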

shoyer commented 4 years ago

> I guess a better question to ask is whether it's possible to implement functions like a depth-first search.

XLA doesn't have built-in data structures like sets or dynamic lists, only multi-dimensional arrays. You can probably write this logic with JAX-compilable functions, but replacing hash tables with linear search is not going to be efficient for large graphs.
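A minimal sketch of that approach, assuming a recent JAX version and a 0/1 adjacency matrix where `graph[u, v] != 0` means an edge u -> v: the set `S` becomes a boolean mask, `S.pop()` becomes a linear search with `argmax`, and the `while` loop becomes `lax.while_loop`. The name `is_dag_kahn` is hypothetical. Each iteration rescans the whole matrix, which illustrates the inefficiency mentioned above.

```python
import jax
import jax.numpy as jnp
from jax import lax

def is_dag_kahn(graph):
  # Kahn's algorithm on fixed-shape arrays so it can be jit'ed and vmap'ed.
  n = graph.shape[0]
  edges = graph.astype(bool)
  ready = ~jnp.any(edges, axis=0)  # replaces the set S: in-degree-zero nodes
  done = jnp.zeros(n, dtype=bool)  # nodes already "popped"

  def cond(state):
    _, ready, done = state
    return jnp.any(ready & ~done)

  def body(state):
    edges, ready, done = state
    # Linear search replaces S.pop(): first ready node not yet processed.
    node = jnp.argmax((ready & ~done).astype(jnp.int32))
    done = done.at[node].set(True)
    edges = edges.at[node, :].set(False)     # drop its outgoing edges
    ready = ready | ~jnp.any(edges, axis=0)  # nodes whose in-degree hit zero
    return edges, ready, done

  edges, _, _ = lax.while_loop(cond, body, (edges, ready, done))
  # Kahn's algorithm removes every edge iff the graph is acyclic.
  return ~jnp.any(edges)

dag = jnp.array([[0, 1, 1], [0, 0, 1], [0, 0, 0]])    # 0->1, 0->2, 1->2
cycle = jnp.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]])  # 0->1->2->0
results = jax.jit(jax.vmap(is_dag_kahn))(jnp.stack([dag, cycle]))
```

Under `vmap`, the loop keeps running until every graph in the batch has finished, so the cost of a batch is set by its slowest graph.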

Other options: