Open ardagoreci opened 3 months ago
Makes sense. We very much plan to support this; however, there is some complexity in how you generate the backward for this, because broadcasts in the forward become reductions in the backward. We are discussing the implementation now and I will keep you posted.
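A minimal sketch of the broadcast-to-reduction point in plain PyTorch (hypothetical shapes, not FlexAttention itself): a bias with no batch dimension broadcasts when added to the scores in the forward, so autograd has to sum its gradient over the batch dimension in the backward.

import torch

B, H, L = 4, 2, 8
scores = torch.randn(B, H, L, L)
bias = torch.randn(H, L, L, requires_grad=True)  # no batch dim

out = (scores + bias).sum()  # bias broadcasts over B in the forward
out.backward()

# autograd reduces (sums) the upstream gradient over B for bias
assert bias.grad.shape == (H, L, L)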
Sounds great, thank you!
+1 to this request! It would be very useful for T5-style learned relative positional biases:
def learned_relative_positional(score, b, h, q_idx, kv_idx):
    r_pos = q_idx - kv_idx
    return score + pos_emb[r_pos]
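For reference, a self-contained sketch of that pattern (hypothetical names and shapes; pos_emb is indexed with an offset so negative relative positions map to valid rows). At the time of this thread it would still hit the captured-buffer assertion, but it shows the intended usage:

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, L, D = 2, 4, 128, 64
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

# one learnable bias per relative position in [-(L-1), L-1]
pos_emb = torch.randn(2 * L - 1, requires_grad=True)

def learned_relative_positional(score, b, h, q_idx, kv_idx):
    # shift so that negative relative positions index valid rows
    r_pos = q_idx - kv_idx + (L - 1)
    return score + pos_emb[r_pos]

out = flex_attention(q, k, v, score_mod=learned_relative_positional)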
Another +1!
Another update: we recently released a preliminary open-source implementation of AlphaFold3 at https://github.com/Ligo-Biosciences/AlphaFold3. The model is still not nearly as scalable as it could be, and the attention ops are mostly memory-bound. I can migrate to FlexAttention as soon as there is bias gradient support!
Another +1!
I would be interested in support for a function like this:
from torch import Tensor
from torch.nn import Embedding

def wrapped_agraph(distances: Tensor, distance_encoders: Embedding):
    """
    Wraps and returns a score-modification function (`score_mod`) for flex attention.

    distances: integer tensor of shape [batch_size, seq_len, seq_len],
        e.g. torch.Size([32, 49, 49]), with values in the range [0, 12].
    distance_encoders: learnable Embedding layer of shape [num_embeddings, heads]
        (see: from torch.nn import Embedding), e.g. Embedding(13, 16).
    """
    def agraph(score, b, h, q_idx, kv_idx):
        idx_d = distances[b, q_idx, kv_idx]
        return score + distance_encoders(idx_d)[h]
    return agraph
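A hypothetical usage sketch for the above, with the shapes taken from the docstring; this assumes the requested gradient support for captured buffers, since the final backward() is exactly what this issue asks for:

import torch
from torch.nn import Embedding
from torch.nn.attention.flex_attention import flex_attention

B, H, L, D = 32, 16, 49, 64
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)
distances = torch.randint(0, 13, (B, L, L))
distance_encoders = Embedding(13, H)

score_mod = wrapped_agraph(distances, distance_encoders)
out = flex_attention(q, k, v, score_mod=score_mod)
out.sum().backward()  # needs grads to flow into distance_encoders.weight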
Another +1
We are actively working on this now and will post back here with further updates.
+1, see #69
Will have updates soon :)
Thanks for the update! I will try it out when it's available. Do you have a release timeline? @drisspg
I was wondering if there are any plans to support gradient flow back to biases that are added in the score function. For instance, if I add a scalar as in score + b, where b comes from a learnable projection, the gradients should flow back to those layers.
Currently, I get the following error when I try this:
AssertionError: Captured buffers that require grad are not yet supported.
This feature would be very useful for the attention mechanisms in AlphaFold for computational biology (there is a lot of pair biasing in those attention components).
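A minimal sketch of how this fails (hypothetical shapes; any tensor with requires_grad=True captured inside score_mod triggers it):

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, L, D = 2, 4, 16, 32
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

pair_bias = torch.randn(H, L, L, requires_grad=True)  # learnable pair bias

def pair_bias_mod(score, b, h, q_idx, kv_idx):
    return score + pair_bias[h, q_idx, kv_idx]

# raises: AssertionError: Captured buffers that require grad are not yet supported.
out = flex_attention(q, k, v, score_mod=pair_bias_mod)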