pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention
BSD 3-Clause "New" or "Revised" License

Bias gradient support? #20

Open ardagoreci opened 2 months ago

ardagoreci commented 2 months ago

I was wondering if there are any plans to support gradient flow back to biases that are added in the score function. For instance, if I add a scalar bias to the score (score + b, where b comes from a learnable projection), the gradients should flow back to those layers.
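A minimal sketch of what I mean (the shapes and the captured bias here are illustrative, not from my actual model):

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 4, 128, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda") for _ in range(3))
# stand-in for the output of a learnable projection
bias = torch.randn(H, S, S, device="cuda", requires_grad=True)

def score_mod(score, b, h, q_idx, kv_idx):
    return score + bias[h, q_idx, kv_idx]

out = flex_attention(q, k, v, score_mod=score_mod)
out.sum().backward()  # gradients should flow back to bias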

Currently, I get the following error when I try this:

AssertionError: Captured buffers that require grad are not yet supported.

This feature would be very useful for the attention mechanisms in AlphaFold for computational biology (there is a lot of pair biasing in those attention components).

drisspg commented 2 months ago

Makes sense. We very much plan to support this; however, there is some complexity in how you generate the backward for this, because broadcasts in the forward become reductions in the backward. We are discussing the implementation now and I will keep you posted.
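As a plain autograd illustration of that point (nothing flex-attention-specific here): a tensor that is broadcast in the forward gets its gradient summed over the broadcast dimensions in the backward.

import torch

b = torch.randn(4, 1, requires_grad=True)  # broadcast along the last dim
scores = torch.randn(4, 8)
(scores + b).sum().backward()              # forward broadcasts b to (4, 8)
# b.grad has shape (4, 1): the backward summed over the broadcast dimension
print(b.grad.shape)                        # torch.Size([4, 1])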

ardagoreci commented 2 months ago

Sounds great, thank you!

eballesteros commented 2 months ago

+1 to this request! It would be very useful for T5-style learned relative positional biases:

def learned_relative_positional(score, b, h, q_idx, kv_idx):
    # pos_emb is a learned embedding table captured from the enclosing scope;
    # negative relative positions index from the end of the table
    r_pos = q_idx - kv_idx
    return score + pos_emb[r_pos]
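For context, a hypothetical setup the sketch above assumes (names and sizes are illustrative): pos_emb is a learned 1-D table with one bias per relative offset, and since negative offsets index from the end of the table, a size of 2 * seq_len - 1 gives every offset its own slot.

import torch

seq_len = 128
# one learned bias per relative offset in [-(seq_len - 1), seq_len - 1]
pos_emb = torch.nn.Parameter(torch.zeros(2 * seq_len - 1))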
spinjo commented 2 months ago

Another +1!

ardagoreci commented 1 month ago

Another update: we recently released a preliminary open-source implementation of AlphaFold3 here: https://github.com/Ligo-Biosciences/AlphaFold3. The model is still not nearly as scalable as it could be, and the attention ops are memory-bound for the most part. I can migrate to FlexAttention as soon as there is bias gradient support!

ds-kczerski commented 1 month ago

Another +1!

I would be interested in support for a function like this:

from torch import Tensor
from torch.nn import Embedding

def wrapped_agraph(distances: Tensor, distance_encoders: Embedding):
    """
    Wraps and returns the score modification function `score_mod` for flex-attention.

    distances - integer tensor of shape [batch_size, seq_len, seq_len]
        (e.g. torch.Size([32, 49, 49])), values in the inclusive range 0..12
    distance_encoders - learnable Embedding layer of shape [num_embeddings, num_heads]
        (e.g. Embedding(13, 16))
    """
    def agraph(score, b, h, q_idx, kv_idx):
        # look up the distance bucket for this (batch, query, key) pair and
        # add the corresponding per-head learned bias to the score
        idx_d = distances[b, q_idx, kv_idx]
        return score + distance_encoders(idx_d)[h]
    return agraph
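A hypothetical call site for the function above (shapes follow the docstring):

import torch
from torch.nn.attention.flex_attention import flex_attention

distances = torch.randint(0, 13, (32, 49, 49))
distance_encoders = Embedding(13, 16)
score_mod = wrapped_agraph(distances, distance_encoders)
# out = flex_attention(q, k, v, score_mod=score_mod)
# fails today: the Embedding weight is a captured buffer that requires grad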
ViktorooReps commented 3 weeks ago

Another +1

drisspg commented 3 weeks ago

We are actively working on this now and will post back here with further updates.