ariG23498 opened this issue 9 months ago
@ariG23498 you are making a couple of mistakes here:

- As per your linear layer definition, the `in_features` should be 16, but the last dimension of your input is 64.
- Every time you add another dimension, you need to apply `vmap` all the way.

Here is the working example in your case:
```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np

key = jr.PRNGKey(0)  # any PRNG key will do
in_features = 64
out_features = 16
batch_size = 3

linear = eqx.nn.Linear(in_features, out_features, key=key, use_bias=True)
inputs = jnp.asarray(np.random.rand(batch_size, 16, 64))
print("Inputs shape: ", inputs.shape)

# One vmap per extra leading dimension (tokens, then batch).
out = jax.vmap(jax.vmap(linear))(inputs)
print("Outputs shape: ", out.shape)
# Inputs shape:  (3, 16, 64)
# Outputs shape: (3, 16, 16)
```
> As per your linear layer definition, the `in_features` should be 16, but the last dimension of your input is 64.
Interesting! I thought equinox was channels first?
> Every time you add another dimension, you need to apply `vmap` all the way.

This is very tedious. How would that play when you build a model with a bunch of layers inside and apply `vmap` to the model?
> Interesting! I thought equinox was channels first?
Umm, looks like you are mixing two things here. For a linear layer, the last dimension of the input is the number of input features.
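A tiny sketch of that convention (the sizes here are just illustrative, not taken from the issue):

```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

# A single example is a vector whose (only) dimension is in_features,
# so the layer maps shape (64,) -> (16,).
linear = eqx.nn.Linear(in_features=64, out_features=16, key=jr.PRNGKey(0))
x = jnp.ones(64)
print(linear(x).shape)  # (16,)
```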
> This is very tedious. How would that play when you build a model with a bunch of layers inside and apply `vmap` to the model?
You `vmap` over the entire model. I will put another example.
Hey there! So the fact that your first example works at all is arguably a bug -- the input to an `eqx.nn.Linear` layer should be a vector, i.e. something with shape `(in_features,)`.
Indeed the intended approach is to use `jax.vmap`. You don't have to do this for every layer! Write your whole model as acting on a single batch element, and then vmap the whole thing in one go. See the following example, where I'm additionally using jaxtyping to help emphasise the shapes of the arrays:
```python
import equinox as eqx
import jax
import jax.random as jr
from jaxtyping import Array, Float, PRNGKeyArray


class TwoLayerMLP(eqx.Module):
    layer1: eqx.nn.Linear
    layer2: eqx.nn.Linear

    def __init__(self, input_size: int, hidden_size: int, output_size: int, key: PRNGKeyArray):
        key1, key2 = jr.split(key, 2)
        self.layer1 = eqx.nn.Linear(input_size, hidden_size, key=key1)
        self.layer2 = eqx.nn.Linear(hidden_size, output_size, key=key2)

    def __call__(self, x: Float[Array, "input_size"]) -> Float[Array, "output_size"]:
        x = self.layer1(x)
        x = jax.nn.relu(x)
        x = self.layer2(x)
        return x


key = jr.key(0)
data_key, model_key = jr.split(key, 2)

batch_size = 32
input_size = 4
hidden_size = 16
output_size = 4

batch_of_x = jr.normal(data_key, (batch_size, input_size))
model = TwoLayerMLP(input_size, hidden_size, output_size, key=model_key)
jax.vmap(model)(batch_of_x)  # Here's the magic!
```
Thanks for the code!
I am interested in passing a tensor of shape `(B, N, in_features)` to a Linear layer (B is the batch size, N is the number of tokens). How would I possibly do this? TIA!
Call `jax.vmap` twice. These can either be inside or outside the model definition, whichever produces simpler code for you. (See how my vmap above was outside.)
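A minimal sketch of what that looks like, continuing from the `TwoLayerMLP` example above (`model` and `input_size` are reused; `B` and `N` are illustrative sizes):

```python
import jax
import jax.random as jr

B, N = 8, 10
tokens = jr.normal(jr.key(1), (B, N, input_size))

# The inner vmap maps over the token dimension and the outer one over the
# batch dimension, so the model itself still only ever sees a single
# vector of shape (input_size,).
out = jax.vmap(jax.vmap(model))(tokens)
print(out.shape)  # (8, 10, 4), i.e. (B, N, output_size)
```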
> Call `jax.vmap` twice. These can either be inside or outside the model definition, whichever produces simpler code for you. (See how my vmap above was outside.)
I am sorry to bother you so much, but does this API design conform with the other frameworks? I am trying to wrap my head around the idea that `Linear` can only accept a vector: the matrix multiplication operation works; we are only held back by the bias term not being broadcast.
I believe Equinox is the only library that does things this way. I might even call it an "innovation" in Equinox -- which is possibly the most recently-created of the mainstream NN libraries -- but I think that's a bit of a grandiose term for something so simple!
From a user point of view I don't think it changes very much: with Equinox, you add a single `jax.vmap` outside your model. In return, you get to write simpler model code, because you only have to think about how your model acts on a single batch element, rather than all of them at once.

(More broadly, this gets at a conceptual design choice, which is that in the presence of `vmap`, I believe it becomes an antipattern to write rank-polymorphic code. With `vmap`, the base operation can always be explicitly batched out to the appropriate rank, without any risk of silent errors due to broadcasting etc.)
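As a toy illustration of the kind of silent broadcasting bug being alluded to (my own example, not taken from this thread):

```python
import jax
import jax.numpy as jnp

preds = jnp.arange(4.0)             # shape (4,)
targets = jnp.arange(4.0)[:, None]  # shape (4, 1) -- an easy shape slip

# Rank-polymorphic code: broadcasting silently produces a (4, 4) array of
# all pairwise differences, so the "loss" is wrong yet no error is raised.
silent_loss = jnp.mean((preds - targets) ** 2)

# Per-example code + vmap: each call pairs one prediction with one target,
# so the batch axis can never be broadcast against the wrong dimension.
def per_example_loss(pred, target):
    return (pred - target) ** 2

correct_loss = jnp.mean(jax.vmap(per_example_loss)(preds, targets))
# correct_loss == 0.0 here (predictions equal targets); silent_loss is not.
```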
That was a beautiful explanation. I think this might be helpful if it were documented somewhere (if I missed it, could you please guide me to the docs)?
Please feel free to close this issue!
That's great! I'm glad that's made sense.
We have this FAQ entry; I'd be happy to hear any feedback on how this might be clarified! :)
> but does this API design conform with the other frameworks?

@ariG23498 I would say this conforms with the JAX philosophy, especially for `vmap`: write it like you are using a single example, and use `vmap` to run it for a batch.
The above-mentioned code works! Notice here that the `use_bias` term is switched off.

Here the code breaks, complaining about the shape incompatibility for the bias addition. I could not figure out a way to bypass this using the `out_features` parameter either.

@soumik12345 pointed me to the Flax source code, where the `bias` term is explicitly broadcast: https://github.com/google/flax/blob/daf06eadab6c9e9bfb64b30d0623245179965155/flax/experimental/nnx/nnx/nn/linear.py#L355

@patrick-kidger I would love to contribute this to the codebase. Also, could you point us to the reason why this is the case? We would love to understand why the broadcast was not implicit.
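For reference, a minimal sketch reconstructing the behaviour described above (the original snippets are not shown here, so the exact shapes are illustrative, and newer Equinox versions may check shapes differently):

```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
x = jnp.ones((3, 16, 64))  # batched input, as in the issue

# The layer is declared with in_features=16, which matches x's second-to-last
# dimension rather than its last one.  With use_bias=False this "works": the
# matmul broadcasts the (16, 16) weight over the leading batch dimension and
# contracts over that second-to-last axis, giving shape (3, 16, 64) -- which
# is arguably a bug, since eqx.nn.Linear is meant to act on a single vector
# of shape (in_features,).
no_bias = eqx.nn.Linear(16, 16, use_bias=False, key=key)
print(no_bias(x).shape)  # (3, 16, 64)

# With use_bias=True the matmul still goes through, but adding the (16,)-shaped
# bias to the (3, 16, 64) result fails -- the reported shape incompatibility.
with_bias = eqx.nn.Linear(16, 16, use_bias=True, key=key)
# with_bias(x)  # raises a broadcasting error at the bias addition
```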