mathDR closed this issue 2 years ago.
objax.Vectorize does require a variable collection when you vectorize a plain callable rather than an objax Module. In your case jnp.diag does not use any variables, so you can just pass an empty variable collection, created with objax.VarCollection().
On top of that, by default objax.Vectorize will try to vectorize the function over all of its arguments. jnp.diag has two arguments: the first one is the input, and the second one, k, is an integer indicating which diagonal to use. So you can only vectorize jnp.diag over its first argument. You can achieve this either by using lambda x: jnp.diag(x, k=0) instead of jnp.diag, or by passing an extra batch_axis argument to objax.Vectorize.
Below are two versions of the code that demonstrate this.
Here is the first version, which uses a lambda:
import jax.numpy as jnp
import objax

x0 = objax.random.uniform((5, 3))
vec_diag = objax.Vectorize(lambda x: jnp.diag(x, k=0), objax.VarCollection())
t = vec_diag(x0)
Here is another version, which uses the batch_axis argument and passes k to the vectorized function:
x0 = objax.random.uniform((5, 3))
vec_diag2 = objax.Vectorize(jnp.diag, objax.VarCollection(), batch_axis=(0, None))
t = vec_diag2(x0, 0)
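For comparison with the jax.vmap code being replaced, here is a minimal sketch of the same two approaches in plain JAX (no objax). The assumption is that jax.vmap's in_axes plays the role that batch_axis plays above; jax.random.uniform stands in for objax.random.uniform, and the names t and t2 are only illustrative.

```python
import jax
import jax.numpy as jnp

# Stand-in for objax.random.uniform((5, 3)): a batch of five length-3 vectors.
x0 = jax.random.uniform(jax.random.PRNGKey(0), (5, 3))

# Version 1: close over k so the mapped function has a single argument.
t = jax.vmap(lambda x: jnp.diag(x, k=0))(x0)

# Version 2: map over axis 0 of the first argument and pass k unbatched
# (in_axes entry None), analogous to batch_axis=(0, None) above.
t2 = jax.vmap(jnp.diag, in_axes=(0, None))(x0, 0)

print(t.shape)  # (5, 3, 3)
```

Both versions should produce five 3x3 matrices whose diagonals are the rows of x0.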
Let me know if you have any other questions.
I'm closing this for now. Feel free to re-open if there are any follow up questions.
Thanks! This solved it. I guess I was surprised that I had to explicitly add default parameters (k=0). Appreciate the help!
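One way to see why the explicit k=0 was needed is to compare the positional parameters of jnp.diag and of the lambda, using Python's standard inspect module. This is a sketch under an assumption taken from the answer above (that objax.Vectorize vectorizes over all arguments by default): a one-entry default batch_axis only fits a callable with one parameter. How objax performs that check internally is not shown in this thread.

```python
import inspect

import jax.numpy as jnp

# jnp.diag exposes two parameters: the input array and the diagonal offset k.
diag_params = list(inspect.signature(jnp.diag).parameters)

# The lambda from the answer exposes only the input array; k is fixed inside.
wrapped = lambda x: jnp.diag(x, k=0)
lambda_params = list(inspect.signature(wrapped).parameters)

# Assumption: objax pairs batch_axis entries with these parameters, so the
# default (0,) matches the lambda but not jnp.diag.
print(diag_params)
print(lambda_params)
```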
I am trying to replace my use of jax.vmap with objax.Vectorize (since mixing these types of operations can cause problems) and am running into issues. When the function I am calling Vectorize on is an objax module, everything works fine (the examples in the test code are great). My issue is in trying to replace something like the following:
where this functionality does something similar to tf.linalg.diag and returns a jnp.array of size (5, 3, 3) (i.e. five 3x3 matrices whose diagonals are the corresponding rows of X0).
Running the objax.Vectorize version fails with
ValueError: You must supply the VarCollection used by the function f
So my question: how can I pass in the VarCollection for this example? Is that possible?
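If the only goal is the (5, 3, 3) stack of diagonal matrices described here, broadcasting can build it without any vectorization wrapper at all. This is a swapped-in technique, not what the thread proposes; the helper name batched_diag is hypothetical, and it only covers the main diagonal (k=0).

```python
import jax
import jax.numpy as jnp


def batched_diag(x: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: turn an (n, d) array into n diagonal (d, d) matrices.

    Row x[i] is placed on the main diagonal of the i-th matrix; all
    off-diagonal entries are zero. Only the k=0 diagonal is supported.
    """
    d = x.shape[-1]
    # x[..., :, None] has shape (n, d, 1); multiplying by the (d, d) identity
    # broadcasts to (n, d, d) and zeroes every off-diagonal entry.
    return x[..., :, None] * jnp.eye(d, dtype=x.dtype)


X0 = jax.random.uniform(jax.random.PRNGKey(0), (5, 3))
out = batched_diag(X0)
print(out.shape)  # (5, 3, 3)
```

One caveat of the multiply-by-identity approach: a non-finite entry (inf or nan) in X0 turns the zeros in its row into nan, whereas jnp.diag would leave them as zeros.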