google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Improve documentation of make_causal_mask #1520

Open rongcuid opened 3 years ago

rongcuid commented 3 years ago

The current make_causal_mask API is overly restrictive because it requires an array of shape [B, *, L], while MultiHeadAttention requires [B, *, L, H] as input. This causes a great deal of confusion (the examples have no shape annotations or comments): if the created mask picks up the feature dimension in its shape, it leads to shape mismatch errors down the road.

I propose changing make_causal_mask to take a shape tuple as input instead of an example input array.
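
A minimal sketch of what that could look like (the name make_causal_mask_from_shape and its signature are hypothetical, not an existing Flax API):

```python
import jax.numpy as jnp

def make_causal_mask_from_shape(shape, dtype=jnp.float32):
    # `shape` is the [batch..., L] shape that make_causal_mask currently
    # infers from an example input array.
    *batch_dims, seq_len = shape
    # [L, L] lower-triangular matrix: query i may attend to keys j <= i.
    tril = jnp.tril(jnp.ones((seq_len, seq_len), dtype=dtype))
    # Broadcast to [batch..., 1, L, L]; the singleton axis broadcasts over heads.
    return jnp.broadcast_to(tril, (*batch_dims, 1, seq_len, seq_len))

mask = make_causal_mask_from_shape((2, 5))
print(mask.shape)  # (2, 1, 5, 5)
```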

As a side issue, many functions have docstrings, but they do not appear in the generated documentation. Most notable are all the __call__ methods, which I think are very important. I am unfamiliar with Sphinx, so I am not sure what kind of issue this would be.

levskaya commented 3 years ago

make_causal_mask accepts an input with an arbitrary number of batch dimensions but a single terminal "length" dimension, i.e. [B..., L] as you said. It then generates a [B..., 1, L, L] causal mask, which broadcasts to the attention shape [B..., H, L, L] when masking the actual attention matrix.
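
For concreteness, a minimal shape check along these lines (a sketch; the dimension sizes are just placeholders):

```python
import jax.numpy as jnp
import flax.linen as nn

B, L, H = 2, 5, 8
x = jnp.ones((B, L))                  # [B..., L] example input
causal_mask = nn.make_causal_mask(x)  # [B..., 1, L, L]
print(causal_mask.shape)              # (2, 1, 5, 5)

# Attention weights have shape [B..., H, L, L]; the mask's singleton axis
# broadcasts over the head dimension H.
attn_weights = jnp.ones((B, H, L, L))
masked = jnp.where(causal_mask > 0, attn_weights, -1e9)
print(masked.shape)                   # (2, 8, 5, 5)
```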

It accepts inputs rather than a tuple shape spec to follow the API of the other general masking function, make_attention_mask, which handles creating masks for padding tokens, segmentation masks for packed examples, etc.

I agree that keeping the shapes straight in transformer code is hard; we should probably add more key shape-annotation comments to make things clearer. Thanks for raising the issue of docstrings not getting through to Sphinx; we'll have to investigate that.

rongcuid commented 3 years ago

The thing is, my input shape is [B, L, H], which requires me to call make_causal_mask(input[:, :, 0]). This can cause trouble: my input pipeline filters out bad samples, so if by chance a 0-sized batch comes in, it can crash the training... five days after starting.
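
In code, the workaround for a [B, L, H] input looks roughly like this (a sketch with placeholder sizes):

```python
import jax.numpy as jnp
import flax.linen as nn

B, L, H = 2, 5, 8
inputs = jnp.ones((B, L, H))                        # input already carries features
# Slice away the feature axis to recover the [B, L] shape make_causal_mask expects.
causal_mask = nn.make_causal_mask(inputs[:, :, 0])  # [B, 1, L, L]
```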

I do not think make_causal_mask should have the same API as make_attention_mask, because the input values are not actually needed at all. The general attention mask can be used for padding and probably other things, but a causal mask is only used for self-attention, and it does not make sense for anything other than square matrices. In fact, the PyTorch Transformer example uses a triu matrix as the causal mask, which probably saves a lot of memory, too.
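
For reference, the triu-style additive mask from the PyTorch Transformer tutorial would look roughly like this in JAX (a sketch; note it is added to the attention logits, whereas Flax's helpers return 0/1 masks):

```python
import jax.numpy as jnp

def square_subsequent_mask(seq_len):
    # 0.0 where attention is allowed (on/below the diagonal), -inf for
    # future positions; the result is added to the raw attention logits.
    allowed = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    return jnp.where(allowed, 0.0, -jnp.inf)

print(square_subsequent_mask(4))
```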

rongcuid commented 3 years ago

I argue this from The Flax Philosophy, which says:

Prefer duplicating code over a bad abstraction.

levskaya commented 3 years ago

Hey - sorry for the delay.

The existing utility functions are not meant to be in any way the "final word" on transformer APIs. One of the problems with transformers is that these days there are ~100 variants of all sorts that tweak just about everything. So we mainly offer a few super classical self-attention layers to illustrate important general techniques, like how to do autoregressive decoding in flax / jax.

I take it you're trying to build the mask inside each layer, then? The existing make_causal_mask / make_attention_mask were both designed to be used at the top-level model layer to prepare, all at once, the masks needed by attention layers throughout the model (intentionally, to make the compiler's job a bit easier rather than assuming common subexpression elimination will just work). At this point the current API has been around long enough that a lot of code depends on the existing behavior, so we can't just delete what's there.
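
For illustration, that intended top-level usage might look like this (a sketch; it assumes token id 0 is padding, and uses make_attention_mask / make_causal_mask / combine_masks as they exist in flax.linen today):

```python
import jax.numpy as jnp
import flax.linen as nn

# [B, L] token ids; 0 is assumed to be the padding id for this sketch.
tokens = jnp.array([[5, 7, 9, 0, 0],
                    [3, 2, 0, 0, 0]])

# Padding mask: [B, 1, L, L], nonzero where both query and key are real tokens.
padding_mask = nn.make_attention_mask(tokens > 0, tokens > 0)
# Causal mask: [B, 1, L, L]; only the shape of `tokens` is used.
causal_mask = nn.make_causal_mask(tokens)
# Combine once here and pass the result down to every attention layer.
mask = nn.combine_masks(padding_mask, causal_mask)
```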

We strongly encourage people to just write their own masking utility functions if the example ones we provide aren't the right fit! As you mention, it's just a broadcast triu. (A side note: it's not so easy to reason about memory usage with XLA, which will aggressively optimize away apparent inefficiencies in user-level code; computed constants like masks are rarely an issue.)

rongcuid commented 3 years ago

I was actually generating the masks only once, at the top level. However, my training data has been in the shape [B, L, H] from the beginning. I am OK with the current implementation, but its documentation probably needs some clarification.

marcvanzee commented 2 years ago

Thanks for the feedback @rongcuid, I have renamed this issue to clarify that we should improve the documentation of make_causal_mask.