google / objax


replacing jax.vmap with objax.Vectorize #235

Closed mathDR closed 2 years ago

mathDR commented 2 years ago

I am trying to replace jax.vmap with objax.Vectorize (since mixing these types of operations can cause problems) and am running into issues. When the function I call Vectorize on is an objax module, everything works fine (the examples in the test code are great).

My issue is in trying to replace something like the following:

X0 = objax.random.uniform((5, 3))
T = jax.vmap(jnp.diag)(X0)

where this does something similar to tf.linalg.diag and returns an array of shape (5, 3, 3) (i.e. five 3x3 matrices whose diagonals are the corresponding rows of X0).
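For reference, the behaviour described above can be sketched in plain JAX. This is only an illustration: it uses jax.random.uniform with a fixed seed as a stand-in for objax.random.uniform, which is an assumption made so the snippet runs without objax.

```python
import jax
import jax.numpy as jnp

# Stand-in for objax.random.uniform((5, 3)); the fixed seed is an assumption.
x0 = jax.random.uniform(jax.random.PRNGKey(0), (5, 3))

# vmap maps jnp.diag over axis 0, so each length-3 row becomes a 3x3 matrix.
t = jax.vmap(jnp.diag)(x0)
print(t.shape)

# Reading the diagonals back out of each matrix recovers the rows of x0.
rows_recovered = jnp.diagonal(t, axis1=1, axis2=2)
print(bool(jnp.allclose(rows_recovered, x0)))
```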

Running

objax.Vectorize(jnp.diag)(X0)

fails with ValueError: You must supply the VarCollection used by the function f

So my question: how can I pass in the VarCollection for this example? Is that possible?

AlexeyKurakin commented 2 years ago

objax.Vectorize does require a variable collection when you vectorize a plain callable. In your case jnp.diag does not use any variables, so you can just pass an empty variable collection, which is created with objax.VarCollection().

On top of that, by default objax.Vectorize tries to vectorize the function over all of its arguments. jnp.diag has two arguments: the first is the input, and the second, k, is an integer indicating which diagonal to use. So you can only vectorize jnp.diag over the first argument. You can achieve this either by using lambda x: jnp.diag(x, k=0) instead of jnp.diag, or by passing an extra batch_axis argument to objax.Vectorize.

Below are two versions of the code which demonstrate this.

Here is one version of the code, which uses a lambda:

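import jax.numpy as jnp
import objax

x0 = objax.random.uniform((5, 3))

# jnp.diag uses no objax variables, so an empty VarCollection is enough.
vec_diag = objax.Vectorize(lambda x: jnp.diag(x, k=0), objax.VarCollection())
t = vec_diag(x0)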

Here is another version, which uses the batch_axis argument and passes k to the vectorized function:

x0 = objax.random.uniform((5, 3))

# batch_axis=(0, None): batch the first argument over axis 0, pass k through unbatched.
vec_diag2 = objax.Vectorize(jnp.diag, objax.VarCollection(), batch_axis=(0, None))
t = vec_diag2(x0, 0)
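For comparison, plain jax.vmap expresses the same "batch only the first argument" idea through its in_axes parameter. The sketch below is an analogy in JAX only, not objax code; the input array and the extra k=1 call are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Illustrative input: 5 rows of length 3 with known values.
x0 = jnp.arange(15.0).reshape(5, 3)

# in_axes=(0, None): map over axis 0 of the first argument and pass the
# second argument (k) through unchanged to every call.
batched_diag = jax.vmap(jnp.diag, in_axes=(0, None))

t0 = batched_diag(x0, 0)  # rows placed on the main diagonal
t1 = batched_diag(x0, 1)  # rows placed on the first superdiagonal

print(t0.shape, t1.shape)
```

With k=1 each length-3 row lands on the first superdiagonal, so every output matrix is 4x4 rather than 3x3.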

Let me know if you have any other questions.

AlexeyKurakin commented 2 years ago

I'm closing this for now. Feel free to re-open if there are any follow-up questions.

mathDR commented 2 years ago

Thanks! This solved it. I guess I was surprised that I had to explicitly add default parameters (k=0). Appreciate the help!