mosaicml / composer

Supercharge Your Model Training
http://docs.mosaicml.com
Apache License 2.0

Optionally use `flash-attn`'s CE loss for metrics #3394

Closed. snarayan21 closed this 3 months ago

snarayan21 commented 3 months ago

What does this PR do?

Resubmission of #3214 -- using FA's CE loss results in lower peak reserved memory usage and higher throughput. We are not adding flash attention as an optional dependency to composer, since that would make installs and correct builds messy and significantly slower.
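A minimal sketch of the optional-dependency pattern the PR describes: prefer flash-attn's fused cross-entropy when it is importable and CUDA is available, otherwise fall back to torch's implementation. The module path `flash_attn.losses.cross_entropy` is flash-attn's CE loss; the helper `make_ce_loss` and the selection logic are illustrative, not Composer's actual code.

```python
import torch

# Assumption: flash-attn may or may not be installed; import lazily and
# record availability instead of making it a hard dependency.
try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss
    FLASH_CE_AVAILABLE = True
except ImportError:
    FLASH_CE_AVAILABLE = False


def make_ce_loss(ignore_index: int = -100):
    """Return flash-attn's fused CE loss when usable, else torch's.

    flash-attn's kernel only supports CUDA tensors, so we also require
    torch.cuda.is_available() before selecting it.
    """
    if FLASH_CE_AVAILABLE and torch.cuda.is_available():
        return FusedCrossEntropyLoss(ignore_index=ignore_index)
    return torch.nn.CrossEntropyLoss(ignore_index=ignore_index)


# Tiny smoke test: (batch, vocab) logits against integer targets.
logits = torch.randn(8, 32)
targets = torch.randint(0, 32, (8,))
loss = make_ce_loss()(logits, targets)
```

On a CPU-only machine this transparently falls back to `torch.nn.CrossEntropyLoss`, which is why the CPU tests discussed below must not accidentally pick up the flash-attn path.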

Fixed a small typo where the Python 3.11 CPU tests were accidentally using the GPU image with flash-attn installed.

Also modified the DeviceGPU class so that it instantiates a gloo backend for CPU tensors, if gloo is available. This handles cases where users want to perform distributed operations on CPU tensors even when training on GPUs.
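An illustrative sketch of the backend-selection idea: `torch.distributed.init_process_group` accepts a device-to-backend mapping string such as `"cpu:gloo,cuda:nccl"`, so a GPU device can register NCCL for CUDA tensors and gloo for CPU tensors in one process group. The helper below only builds that string; its name and shape are hypothetical, not Composer's actual API.

```python
def select_dist_backend(gloo_available: bool) -> str:
    """Build the backend spec for a GPU device's process group.

    Hypothetical helper: mirrors the idea of registering gloo for CPU
    tensors alongside NCCL when gloo is available.
    """
    if gloo_available:
        # NCCL serves collectives on CUDA tensors; gloo serves CPU tensors.
        return "cpu:gloo,cuda:nccl"
    # Without gloo, only CUDA tensors can participate in collectives.
    return "nccl"


print(select_dist_backend(True))   # -> cpu:gloo,cuda:nccl
print(select_dist_backend(False))  # -> nccl
```

The returned string would be passed as the `backend` argument to `torch.distributed.init_process_group`; in real code the gloo check would come from `torch.distributed.is_gloo_available()`.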

Manual tests:

4th time's the charm :0

- Run with torch CE loss (green): tiny-sp-dtms1-32h-wCFWfa
- Run with FA CE loss (tan): tiny-sp-dtms1-32h-jOfIPL

(Two screenshots, 2024-06-11 at 3:53 PM and 3:54 PM, comparing the two runs)

ShashankMosaicML commented 3 months ago

Can we offer a flag to gate this as well? IIRC there are occasionally numerics issues for long sequences...

@ShashankMosaicML do you remember?

Flash attention fixed the long-sequence issue in this commit: https://github.com/Dao-AILab/flash-attention/commit/c79de85ffa0d19b80fa468f90c5086e837499d72

snarayan21 commented 3 months ago

Seeing the error below on CPU tests:

```
>       assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
E       AssertionError: Only support CUDA tensors
```

So I'm going to add a check for `torch.cuda.is_available()`.

snarayan21 commented 3 months ago

jk. The Python 3.11 CPU tests were using the CUDA image by accident, which caused this problem. It was only the Python 3.11 tests, too. Fixed that in this PR as well.

snarayan21 commented 3 months ago

Added manual test names to PR description