hpcaitech / FastFold

Optimizing AlphaFold Training and Inference on GPU Clusters
Apache License 2.0

triton softmax support multi-batch #152

Open Gy-Lu opened 1 year ago

Gy-Lu commented 1 year ago
  1. Support an extra batch dimension for softmax. In training or batch inference we may prepend a batch dimension to some tensors, but the kernel uses the third dimension (tensor.shape[2]) as the head dim, which gets shifted by the extra leading dimension. In this PR I change it to tensor.shape[-3], which keeps pointing at the head dim regardless of leading batch dimensions (see the sketch after this list). The CUDA kernel is modified as well.

  2. Enable test_atten_core; this test is skipped by default and has never been used.
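
For illustration, here is a minimal sketch of why the negative index survives a prepended batch dimension while the positive one does not. The shapes are made up and are not FastFold's real layout; the only assumption is that the head dim sits third from the end.

```python
import torch

no_heads = 4
# Hypothetical 5-D unbatched layout where the head dim is third from the last,
# so index 2 from the front and index -3 from the back coincide.
t = torch.randn(8, 16, no_heads, 32, 32)
assert t.shape[2] == no_heads and t.shape[-3] == no_heads

# Prepend a batch dimension, as in training / batch inference.
tb = t.unsqueeze(0).expand(2, -1, -1, -1, -1)
assert tb.shape[-3] == no_heads   # negative index still finds the head dim
assert tb.shape[2] != no_heads    # positive index now points one dim too early
```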

Gy-Lu commented 1 year ago

I found that it fails when the batch dimension reaches 2.

Gy-Lu commented 1 year ago

Update: it worked incorrectly before this commit, because the way bias_ptr was computed did not support multi-batch. I have fixed the Triton version; the CUDA version may support multi-batch one day :(
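
For context, a minimal sketch of the kind of fix described above, not FastFold's actual kernel: a Triton row-wise softmax with an additive bias where the bias pointer is offset by the batch index. The layout (x of shape (batch, rows, N), bias of shape (batch, N) broadcast over rows) and all names are invented for illustration.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def softmax_bias_kernel(x_ptr, bias_ptr, out_ptr,
                        rows_per_batch, N, BLOCK: tl.constexpr):
    row = tl.program_id(0)          # one program instance per softmax row
    batch = row // rows_per_batch   # which batch this row belongs to
    cols = tl.arange(0, BLOCK)
    mask = cols < N
    x = tl.load(x_ptr + row * N + cols, mask=mask, other=-float("inf"))
    # The kind of fix described in the comment: include the batch offset when
    # forming the bias pointer, instead of assuming a single batch.
    b = tl.load(bias_ptr + batch * N + cols, mask=mask, other=0.0)
    x = x + b
    x = x - tl.max(x, axis=0)       # numerically stable softmax
    num = tl.exp(x)
    tl.store(out_ptr + row * N + cols, num / tl.sum(num, axis=0), mask=mask)

def softmax_with_bias(x, bias):
    # x: (batch, rows, N) contiguous, bias: (batch, N)
    batch, rows, N = x.shape
    out = torch.empty_like(x)
    grid = (batch * rows,)
    softmax_bias_kernel[grid](x, bias, out, rows, N,
                              BLOCK=triton.next_power_of_2(N))
    return out
```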