patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

`nn.Linear` does not support broadcasting for the bias term #662

Open ariG23498 opened 8 months ago

ariG23498 commented 8 months ago
import equinox as eqx
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# (B, C, H*W)
inputs = jnp.ones((3, 16, 64))
linear_layer = jax.vmap(eqx.nn.Linear(
    16, 16, use_bias=False, key=key
))

outputs = linear_layer(inputs)
print(outputs.shape)

The code above works! Notice that use_bias is switched off here.

key = jax.random.PRNGKey(0)

# (B, C, H*W)
inputs = jnp.ones((3, 16, 64))
linear_layer = jax.vmap(eqx.nn.Linear(
    16, 16, use_bias=True, key=key
))

outputs = linear_layer(inputs)
print(outputs.shape)

Here the code breaks, complaining about a shape incompatibility in the bias addition. I couldn't find a way to bypass this using the out_features parameter either.
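The failure comes down to broadcasting. Below is a minimal sketch of the shapes involved, assuming `eqx.nn.Linear` computes `weight @ x + bias` with `weight` of shape `(out_features, in_features)` and `bias` of shape `(out_features,)`. Inside `jax.vmap`, each batch element has shape `(16, 64)`, so the matrix product is also `(16, 64)`. Broadcasting then aligns the trailing axis (64) with the bias length (16), and the two do not match.

```python
import jax.numpy as jnp

# Shapes seen inside jax.vmap for one batch element (assumed from the example above).
weight = jnp.ones((16, 16))   # (out_features, in_features)
bias = jnp.ones((16,))        # (out_features,)
x = jnp.ones((16, 64))        # one element of the (3, 16, 64) input

y = weight @ x                # (16, 16) @ (16, 64) -> (16, 64): the matmul "works"
print(y.shape)

# Broadcasting compares trailing axes: 64 vs 16 do not match, so the addition raises.
try:
    _ = y + bias
    broadcast_ok = True
except (TypeError, ValueError):
    broadcast_ok = False
print(broadcast_ok)
```

Without a bias, nothing checks that `x` was a vector, which is why the `use_bias=False` version appears to run.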

@soumik12345 pointed me to the Flax source code, where the bias term is explicitly broadcast: https://github.com/google/flax/blob/daf06eadab6c9e9bfb64b30d0623245179965155/flax/experimental/nnx/nnx/nn/linear.py#L355

@patrick-kidger I would love to contribute this to the codebase. Also, could you explain why this is the case? We would love to understand why the broadcast isn't implicit.

AakashKumarNain commented 8 months ago

@ariG23498 you are making a couple of mistakes here:

  1. Per your linear layer definition, in_features is 16, but the last dimension of your input is 64.
  2. Every time you add another batch dimension, you need to add another vmap.

Here is the working example in your case:


import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np

key = jax.random.PRNGKey(0)
in_features = 64
out_features = 16
batch_size = 3
linear = eqx.nn.Linear(in_features, out_features, key=key, use_bias=True)

inputs = jnp.asarray(np.random.rand(batch_size, 16, 64))
print("Inputs shape: ", inputs.shape)

# The inner vmap maps over the 16 rows; the outer vmap maps over the batch.
out = jax.vmap(jax.vmap(linear))(inputs)
print("Outputs shape: ", out.shape)

# Inputs shape:  (3, 16, 64)
# Outputs shape:  (3, 16, 16)

ariG23498 commented 8 months ago
  1. Per your linear layer definition, in_features is 16, but the last dimension of your input is 64.

Interesting! I thought equinox was channels first?

  2. Every time you add another batch dimension, you need to add another vmap.

This is very tedious. How would that play out when you build a model with a bunch of layers inside and apply vmap to the whole model?

AakashKumarNain commented 8 months ago

Interesting! I thought equinox was channels first?

Umm, it looks like you are mixing up two things here. The channels-first convention applies to layers such as convolutions; for a linear layer, the last dimension of the input is the number of input features.

This is very tedious. How would that play when you build a model with a bunch of layers inside and apply vmap to the model?

You vmap over the entire model. I'll post another example.

patrick-kidger commented 8 months ago

Hey there! So the fact that your first example works at all is arguably a bug -- the input to an eqx.nn.Linear layer should be a vector, i.e. something with shape (in_features,).

Indeed the intended approach is to use jax.vmap. You don't have to do this for every layer! Write your whole model as acting on a single batch element, and then vmap the whole thing in one go. See the following example, where I'm additionally using jaxtyping to help emphasise the shapes of the arrays:

import equinox as eqx
import jax
import jax.random as jr
from jaxtyping import Array, Float, PRNGKeyArray

class TwoLayerMLP(eqx.Module):
    layer1: eqx.nn.Linear
    layer2: eqx.nn.Linear

    def __init__(self, input_size: int, hidden_size: int, output_size: int, key: PRNGKeyArray):
        key1, key2 = jr.split(key, 2)
        self.layer1 = eqx.nn.Linear(input_size, hidden_size, key=key1)
        self.layer2 = eqx.nn.Linear(hidden_size, output_size, key=key2)

    def __call__(self, x: Float[Array, "input_size"]) -> Float[Array, "output_size"]:
        x = self.layer1(x)
        x = jax.nn.relu(x)
        x = self.layer2(x)
        return x

key = jr.key(0)
data_key, model_key = jr.split(key, 2)
batch_size = 32
input_size = 4
hidden_size = 16
output_size = 4
batch_of_x = jr.normal(data_key, (batch_size, input_size))
model = TwoLayerMLP(input_size, hidden_size, output_size, key=model_key)
jax.vmap(model)(batch_of_x)  # Here's the magic!

ariG23498 commented 8 months ago

Thanks for the code!

I am interested in passing a tensor of shape (B, N, in_features) to a Linear layer, where B is the batch size and N is the number of tokens. How would I do this? TIA!

patrick-kidger commented 8 months ago

Call jax.vmap twice. The vmaps can go either inside or outside the model definition, whichever produces simpler code for you. (Note how my vmap above was outside.)

ariG23498 commented 8 months ago

Call jax.vmap twice. The vmaps can go either inside or outside the model definition, whichever produces simpler code for you. (Note how my vmap above was outside.)

Sorry to bother you so much, but does this API design conform to other frameworks? I'm trying to wrap my head around the idea that Linear can only accept a vector: the matrix multiplication already works, and we're only held back because the bias term isn't broadcast.

patrick-kidger commented 8 months ago

I believe Equinox is the only library that does things this way. I might even call it an "innovation" in Equinox -- which is possibly the most recently-created of the mainstream NN libraries -- but I think that's a bit of a grandiose term for something so simple!

From a user point of view I don't think it changes very much: with Equinox, you add a single jax.vmap outside your model. In return, you get to write simpler model code, because you only have to think about how your model acts on a single batch element, rather than all of them at once.

(More broadly, this gets at a conceptual design choice: in the presence of vmap, I believe it becomes an antipattern to write rank-polymorphic code. With vmap, the base operation can always be explicitly batched out to the appropriate rank, without any risk of silent errors due to broadcasting etc.)
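To illustrate the silent-error risk, here is a small hypothetical example (not taken from Equinox): a per-channel bias added to a channels-first array of shape `(C, L)` with `C == L`. The broadcasting version runs without complaint but adds the bias along the wrong axis. The vmap version states the mapped axis explicitly through `in_axes`.

```python
import jax
import jax.numpy as jnp

C = L = 4
bias = jnp.arange(C, dtype=jnp.float32)  # one bias value per channel
x = jnp.zeros((C, L))                    # channels-first: x[c, l]

# Rank-polymorphic version: relies on broadcasting. Because C == L the shapes line up,
# so no error is raised, but bias[l] is added along the last (length) axis, not per channel.
silent = x + bias
print(silent[:, 0])  # all zeros: every channel received bias[0]

# vmap version: write the operation for a single channel, then map over axis 0 of both inputs.
def add_channel_bias(row, b):  # row: (L,), b: scalar bias for that channel
    return row + b

explicit = jax.vmap(add_channel_bias, in_axes=(0, 0))(x, bias)
print(explicit[:, 0])  # channel c received bias[c]
```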

ariG23498 commented 8 months ago

That was a beautiful explanation. It would be helpful if this were documented somewhere. If I missed it, could you point me to the docs?

Please feel free to close this issue!

patrick-kidger commented 8 months ago

That's great! I'm glad that's made sense.

We have this FAQ entry; I'd be happy to hear any feedback on how this might be clarified! :)

AakashKumarNain commented 8 months ago

but does this API design conform with the other frameworks?

@ariG23498 I would say this conforms to the JAX philosophy, especially for vmap: write your code as if it handles a single example, and use vmap to run it over a batch.