vmap should accept a dim_size=None argument that lets the user specify the size of the dimension being vmapped over. It should behave similarly to JAX's axis_size argument.
The net effect of this is that one should be able to vmap over functions that do not take Tensors as input!
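To make the proposed semantics concrete, here is a minimal loop-based sketch in plain Python. Everything in it is hypothetical: toy_vmap is not PyTorch's vmap, Python lists stand in for Tensors, in_dims entries are limited to 0 (batched along the leading dimension) or None (passed unchanged), and results are returned as a list instead of being stacked. It shows the two rules dim_size would add: the batch size can come from dim_size when no input is batched, and an explicit dim_size must agree with any batched inputs.

```python
from typing import Any, Callable, List, Optional, Sequence


def toy_vmap(
    func: Callable[..., Any],
    in_dims: Optional[Sequence[Optional[int]]] = None,
    dim_size: Optional[int] = None,
) -> Callable[..., List[Any]]:
    """Hypothetical loop-based vmap illustrating a dim_size argument.

    Batched inputs are plain Python lists (stand-ins for Tensors) mapped over
    their leading dimension; in_dims[i] is 0 for a batched argument and None
    for an argument passed unchanged to every call.
    """

    def wrapped(*args: Any) -> List[Any]:
        # Default: every positional argument is batched along dimension 0.
        dims = list(in_dims) if in_dims is not None else [0] * len(args)
        if len(dims) != len(args):
            raise ValueError("in_dims must have one entry per positional argument")
        for d in dims:
            if d not in (0, None):
                raise ValueError("this sketch only supports in_dims entries of 0 or None")
        # Collect the batch size implied by each batched argument and by dim_size.
        sizes = {len(a) for a, d in zip(args, dims) if d == 0}
        if dim_size is not None:
            sizes.add(dim_size)
        if not sizes:
            raise ValueError("no batched inputs: dim_size must be specified")
        if len(sizes) > 1:
            raise ValueError(f"inconsistent batch sizes: {sorted(sizes)}")
        (size,) = sizes
        results = []
        for i in range(size):
            # Slice batched arguments; pass unbatched ones through unchanged.
            call_args = [a[i] if d == 0 else a for a, d in zip(args, dims)]
            results.append(func(*call_args))
        # A real implementation would vectorize and stack along out_dims instead.
        return results

    return wrapped


# A function with no Tensor inputs: only mappable because dim_size is given.
print(toy_vmap(lambda: 1.0, in_dims=(), dim_size=3)())  # [1.0, 1.0, 1.0]
# Batch size inferred from the first argument; the second is unbatched.
print(toy_vmap(lambda x, y: x + y, in_dims=(0, None))([1, 2, 3], 10))  # [11, 12, 13]
```

In this sketch, passing a dim_size that disagrees with a batched input raises an error rather than silently picking one size; whether the real API should do the same is a design question for this issue.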
We should also investigate whether JAX's axis_size argument provides anything else that we should support.