Closed: snarayan21 closed this pull request 3 months ago
Can we offer a flag to gate this as well? IIRC there are occasionally numerics issues for long sequences...
@ShashankMosaicML do you remember?
Flash attention fixed the long-sequence issue in this commit: https://github.com/Dao-AILab/flash-attention/commit/c79de85ffa0d19b80fa468f90c5086e837499d72
Seeing the error below on CPU tests:
> assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
E AssertionError: Only support CUDA tensors
So I'm gonna add a check for torch.cuda.is_available().
jk. The torch 3.11 CPU tests were using the CUDA image by accident, causing this problem; only those tests were affected. Fixed that in this PR as well.
Added manual test names to the PR description.
What does this PR do?
Resubmission of #3214 -- using FA's CE loss results in lower peak reserved memory usage and higher throughput. We are not adding flash attention as an optional dependency to composer, since that makes installs and correct builds messy and much slower.
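For illustration, here is a minimal sketch of the guarded-import pattern this implies. This is not the PR's actual code: the helper name `get_cross_entropy_loss` is hypothetical, and it assumes flash-attn's `flash_attn.losses.cross_entropy.CrossEntropyLoss`.

```python
import torch

try:
    # flash-attn's fused CE loss kernel; it only supports CUDA tensors.
    from flash_attn.losses.cross_entropy import CrossEntropyLoss as FACrossEntropyLoss
    _FA_CE_AVAILABLE = True
except ImportError:
    # flash-attn is not a composer dependency, so this import may fail.
    _FA_CE_AVAILABLE = False


def get_cross_entropy_loss(ignore_index: int = -100) -> torch.nn.Module:
    """Return FA's fused CE loss when usable, else torch's implementation."""
    if _FA_CE_AVAILABLE and torch.cuda.is_available():
        return FACrossEntropyLoss(ignore_index=ignore_index)
    return torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
```

The point of a pattern like this is that flash-attn stays a soft dependency: installs without it keep working, and the fused kernel is only used when CUDA is present.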
Fixed a small typo where the torch 3.11 CPU tests were using the GPU image with flash attn installed by accident.
Also modified the DeviceGPU class so that it instantiates a gloo backend for CPU tensors, if gloo is available. This handles cases where users may want to perform distributed operations with tensors present on CPU even if they are using GPUs.
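As a rough sketch of what this enables (not the PR's implementation; it assumes the default process group was already initialized, e.g. with nccl):

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group(backend='nccl') already ran for GPU work.
if dist.is_initialized() and dist.is_gloo_available():
    # An extra gloo-backed group that accepts CPU tensors, so collectives
    # on CPU data still work inside a GPU (nccl) job.
    cpu_group = dist.new_group(backend='gloo')

    cpu_tensor = torch.ones(4)  # stays on CPU; nccl would reject it
    dist.all_reduce(cpu_tensor, group=cpu_group)  # sums across ranks via gloo
```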
Manual tests:
- Start run (13b-dense-fsdp-fullshard-hsdp-adam-shardedckpt-start-5PtEdK), resumed with this branch (13b-dense-fsdp-fullshard-hsdp-adam-shardedckpt-resume-E5SieL)
- Start run (13b-dense-fsdp-fullshard-hsdp-adam-shardedckpt-start-0g8uD4), resumed with dev branch (13b-dense-fsdp-fullshard-hsdp-adam-shardedckpt-resume-TSGoUC)
4th time's the charm :0
Run with torch CE loss (green): tiny-sp-dtms1-32h-wCFWfa
Run with FA CE loss (tan): tiny-sp-dtms1-32h-jOfIPL
What issue(s) does this change relate to?
Before submitting
Did you run pre-commit on your change? (see the pre-commit section of prerequisites)