lucidrains / ring-attention-pytorch

Implementation of 💍 Ring Attention, from Liu et al. at Berkeley AI, in PyTorch
MIT License
452 stars · 26 forks

Working towards CUDA use with NCCL backend #9

Closed Damjan-Kalajdzievski closed 5 months ago

Damjan-Kalajdzievski commented 6 months ago

ring_flash_attn_cuda was mistyped as ring_flash_attention_cuda in the import

EDIT: Changing this PR to a draft, as further changes are necessary and it is not yet working with CUDA.


Description

This PR seeks to get the codebase running with CUDA. We use a minimally modified version of assert.py, "assert_cuda.py", as a test file to check whether the codebase can run on 8 GPUs. Note: you must specify the environment variable CUDA_VISIBLE_DEVICES, e.g.:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python assert_cuda.py
Environment info running on a single node with 8x40GB A100s ``` absl-py==2.0.0 accelerate==0.23.0 aiohttp==3.9.1 aiosignal==1.3.1 annotated-types==0.6.0 anthropic==0.9.0 anyio==4.2.0 appdirs==1.4.4 asttokens==2.4.1 async-timeout==4.0.3 attrs==23.2.0 beartype==0.17.2 bitsandbytes==0.41.2.post2 cachetools==5.3.2 certifi==2023.11.17 chardet==5.2.0 charset-normalizer==3.3.2 chex==0.1.85 click==8.1.7 colorama==0.4.6 comm==0.2.1 contourpy==1.2.0 cycler==0.12.1 DataProperty==1.0.1 datasets==2.14.6 debugpy==1.8.0 decorator==5.1.1 deepspeed==0.12.2 dill==0.3.7 distro==1.9.0 docker-pycreds==0.4.0 docstring-parser==0.15 einops==0.7.0 einx==0.1.3 etils==1.5.2 evaluate==0.4.0 exceptiongroup==1.2.0 executing==2.0.1 fastapi==0.109.0 filelock==3.13.1 flash-attn==2.5.2 flax==0.7.0 fonttools==4.47.2 frozendict==2.4.0 frozenlist==1.4.1 fsspec==2023.10.0 gitdb==4.0.11 GitPython==3.1.41 google-auth==2.26.1 google-auth-oauthlib==1.2.0 grpcio==1.60.0 h11==0.14.0 hjson==3.1.0 httpcore==1.0.2 httpx==0.26.0 huggingface-hub==0.20.2 idna==3.6 importlib-metadata==7.0.1 importlib-resources==6.1.1 ipykernel==6.29.0 ipython==8.18.1 jax==0.4.25 jaxlib==0.4.25 jedi==0.19.1 Jinja2==3.1.2 joblib==1.3.2 jsonlines==4.0.0 jsonschema==4.20.0 jsonschema-specifications==2023.12.1 jupyter_client==8.6.0 jupyter_core==5.7.1 kiwisolver==1.4.5 lxml==5.1.0 Markdown==3.5.2 markdown-it-py==3.0.0 markdown2==2.4.12 MarkupSafe==2.1.3 matplotlib==3.8.2 matplotlib-inline==0.1.6 mbstrdecoder==1.1.3 mdurl==0.1.2 ml-dtypes==0.3.2 mpmath==1.3.0 msgpack==1.0.7 multidict==6.0.4 multiprocess==0.70.15 nest-asyncio==1.5.9 networkx==3.2.1 nh3==0.2.15 ninja==1.11.1.1 nltk==3.8.1 numexpr==2.8.8 numpy==1.26.3 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.18.1 nvidia-nvjitlink-cu12==12.3.101 nvidia-nvtx-cu12==12.1.105 oauthlib==3.2.2 openai==0.28.1 opt-einsum==3.3.0 optax==0.1.9 orbax-checkpoint==0.5.3 packaging==23.2 pandas==2.1.4 parso==0.8.3 pathvalidate==3.2.0 peft==0.7.1 pexpect==4.9.0 pillow==10.2.0 platformdirs==4.1.0 portalocker==2.8.2 prompt-toolkit==3.0.43 protobuf==3.20.2 psutil==5.9.7 ptyprocess==0.7.0 pure-eval==0.2.2 py-cpuinfo==9.0.0 pyarrow==14.0.2 pyasn1==0.5.1 pyasn1-modules==0.3.0 pybind11==2.11.1 pydantic==1.10.13 pydantic_core==2.14.6 Pygments==2.17.2 pynvml==11.5.0 pyparsing==3.1.1 pytablewriter==1.2.0 python-dateutil==2.8.2 pytz==2023.3.post1 PyYAML==6.0.1 pyzmq==25.1.2 ray==2.9.0 referencing==0.32.1 regex==2023.12.25 requests==2.31.0 requests-oauthlib==1.3.1 responses==0.18.0 rich==13.7.0 rouge-score==0.1.2 rpds-py==0.16.2 rsa==4.9 sacrebleu==2.4.0 safetensors==0.4.1 scikit-learn==1.3.2 scipy==1.11.4 seaborn==0.13.2 sentencepiece==0.1.99 sentry-sdk==1.39.2 setproctitle==1.3.3 shortuuid==1.0.11 shtab==1.6.5 six==1.16.0 smmap==5.0.1 sniffio==1.3.0 sqlitedict==2.1.0 stack-data==0.6.3 starlette==0.35.1 svgwrite==1.4.3 sympy==1.12 tabledata==1.3.3 tabulate==0.9.0 tcolorpy==0.1.4 tensorboard==2.15.1 tensorboard-data-server==0.7.2 tensorstore==0.1.53 threadpoolctl==3.2.0 tiktoken==0.5.2 tokenizers==0.15.0 toolz==0.12.1 torch==2.1.2 tornado==6.4 tqdm==4.66.1 tqdm-multiprocess==0.0.11 traitlets==5.14.1 transformers==4.36.2 triton==2.1.0 trl==0.7.7 typepy==1.3.2 typing_extensions==4.9.0 tyro==0.6.3 tzdata==2023.4 urllib3==2.1.0 uvicorn==0.25.0 wandb==0.16.2 wavedrom==2.0.3.post3 
wcwidth==0.2.13 Werkzeug==3.0.1 xxhash==3.4.1 yarl==1.9.4 zipp==3.17.0 zstandard==0.22.0 ```
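
For context, assert_cuda.py presumably follows the pattern of the repository's assert.py: it spawns one process per GPU and initializes a distributed process group, here with the NCCL backend. Below is a minimal sketch of that setup under those assumptions; the function name, port, and inner body are placeholders, not the actual contents of assert_cuda.py.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    # one process per visible GPU, all communicating over the NCCL backend
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'   # assumed free port
    dist.init_process_group('nccl', rank = rank, world_size = world_size)
    torch.cuda.set_device(rank)

    # ... build the ring attention transformer and run the forward /
    # backward equivalence checks here, as assert_cuda.py does ...

    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(run, args = (world_size,), nprocs = world_size)
```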

State of changes before rebase on 0.2.9

After some minor changes to get things kicked off (e.g. a missing import, typos, ...), the ring attention transformer forward pass hangs in send_and_receive_ on dist.recv(receive_buffer, receive_from_rank). Replacing the isend/irecv requests in send_and_receive_ with a batched request fixes this (see the comment below).

With the batched request change turned on, the ring attention transformer fails in the flash attention backward pass with:

File "/home/ubuntu/work/ring-attention-pytorch/ring_attention_pytorch/ring_flash_attention_cuda.py", line 759, in backward
    ring_dq, ring_dk, ring_dv, *_ = _flash_attn_backward(
  File "/home/ubuntu/work/ring_mlp/ringmlpvenv/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 133, in _flash_attn_backward
    dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
RuntimeError: out must have shape (batch_size, seqlen_q, num_heads, head_size)

When striped_ring_attn=False is set in assert_cuda.py, the ring backward pass returns from _flash_attn_backward successfully, but hangs at an undetermined later point.

State as of 5532b7b

Both striped and non-striped ring attention get through the computation, but the forward pass outputs are not close:

torch.allclose(ring_out, flash_out, atol=1e-6), "output is not the same"
> False
a = ring_out - flash_out
a.max()
> tensor(0.6748, grad_fn=<MaxBackward1>)
a.mean()
> tensor(0.0011, grad_fn=<MeanBackward0>)

However, the gradients are close, with a max difference on the order of 0.0004.

lucidrains commented 6 months ago

ah nice find, yea i ran out of time before testing the ring cuda integration with attention

want to bump the patch version and i'll release it?

Damjan-Kalajdzievski commented 6 months ago

Actually, do you want me to get the code running on 8 A100s with the NCCL backend first, before bumping the version? I can batch the PRs however you'd like. I haven't been able to get it running yet, and found a few other things, for example:

Another change I had to make was to replace the isend/irecv requests in send_and_receive_ with batched ones:

import torch.distributed as dist

def send_and_receive_(x, receive_buffer, send_to_rank, receive_from_rank):
    # batch the point-to-point send and receive into a single request
    ops = [
        dist.P2POp(dist.isend, x, send_to_rank),
        dist.P2POp(dist.irecv, receive_buffer, receive_from_rank)
    ]

    reqs = dist.batch_isend_irecv(ops)

    # wait for both transfers to finish, then synchronize all ranks
    for req in reqs:
        req.wait()

    dist.barrier()

Otherwise, the code would hang on the dist.isend request. I'm not sure if this is specific to my machine setup, but most codebases do seem to prefer dist.batch_isend_irecv, so maybe it's less buggy.

lucidrains commented 6 months ago

@Damjan-Kalajdzievski sure! i didn't use the batched version as for the autoregressive case i just concatted the key / values and sent them in one go, but if somehow it is more mature, let's go with that (once you tested it of course)

better yet, keep the logic for both and introduce a flag that allows for one or the other. i went through some github issue a few weeks ago and it felt like the batched version was less mature on quick perusal
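
For illustration only, here is a minimal sketch of what such a flag could look like; the `use_batched` argument and the surrounding structure are assumptions, not the repository's actual API:

```python
import torch.distributed as dist

def send_and_receive_(x, receive_buffer, send_to_rank, receive_from_rank, use_batched = True):
    if use_batched:
        # single batched request; some NCCL setups appear to handle this more reliably
        ops = [
            dist.P2POp(dist.isend, x, send_to_rank),
            dist.P2POp(dist.irecv, receive_buffer, receive_from_rank)
        ]
        reqs = dist.batch_isend_irecv(ops)
    else:
        # separate non-blocking send / receive requests
        reqs = [
            dist.isend(x, send_to_rank),
            dist.irecv(receive_buffer, receive_from_rank)
        ]

    for req in reqs:
        req.wait()

    dist.barrier()
```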

Damjan-Kalajdzievski commented 6 months ago

@lucidrains Sorry for the delay, but I got stuck after running into some hard-to-debug issues. I have changed this PR to a draft, which will hopefully progress toward resolving these issues, described in the edited first comment.

lucidrains commented 6 months ago

@Damjan-Kalajdzievski thanks for the time you've put into this

could you try 0.2.9 and see how that fares for both striped and unstriped?

Damjan-Kalajdzievski commented 6 months ago

Cool, I didn't see those changes; I rebased this branch on them.

For striped I get the same error during the ring attention backward in flash_attn_cuda.bwd

dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
RuntimeError: dout must have shape (batch_size, seqlen_q, num_heads, head_size_og))

and for non-striped I get through the computation, but the allclose assert check fails for the output:

assert torch.allclose(ring_out, flash_out, atol=1e-6), "output is not the same"

When I examine how close the two tensors are, I get

a = ring_out - flash_out
a.max()
> tensor(0.6748, grad_fn=<MaxBackward1>)
a.mean()
> tensor(0.0011, grad_fn=<MeanBackward0>)

The backward pass grad check does report that the grads are close, though.

torch.allclose(ring_embed_grad, flash_embed_grad, atol=1e-2)
> True

With the max difference being on the order of 0.0004.

lucidrains commented 6 months ago

@Damjan-Kalajdzievski ok, try yet again for striped? and yea, i'm sure it must be something small, seems close. you may need to help debug this, as i am in the midst of interviews etc

Damjan-Kalajdzievski commented 6 months ago

Great, that brought striped attention in line with what is happening for non-striped (i.e. the allclose check on the forward pass fails by about the same magnitude, but the gradient closeness check passes). It is merged in. Yes, I should be able to debug the forward pass; I will get to that after some other things I have to prioritize. Thanks!

lucidrains commented 6 months ago

Thank you Damjan 🙏

edit: the strange thing is that it should be unlikely for the gradients to match when the outputs do not. maybe a clue

lucidrains commented 5 months ago

@Damjan-Kalajdzievski hey Damjan, thanks for the PR once again

finally found some off time to test this out myself, and I believe the non-striped causal version is working fine. it is the striped version that is broken (only for output and not gradient for some reason, like you noted), and i will probably get to the bottom of that by end of week. could you verify that non-striped passes on your end?

edit: ah, so the gradients are also failing for striped causal version of ring + flash cuda. the difference was just too small to be detected

lucidrains commented 5 months ago

going to close this PR, as the main branch now contains most of the fixes

lucidrains commented 5 months ago

@Damjan-Kalajdzievski found the offending issue with striped ring attention. turns out at the moment it only works for 1 bucket per machine

added an assert for now, but will plan on supporting multiple buckets
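
For illustration, the guard presumably looks something like the following; the variable names here are assumptions, not the actual code:

```python
# hypothetical sketch of the guard described above; names are assumed
assert not striped_ring_attn or buckets_per_machine == 1, \
    'striped ring attention currently only supports one bucket per machine'
```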

lucidrains commented 5 months ago

@Damjan-Kalajdzievski all tests should pass now with multiple cuda devices over NCCL; please open an issue if it doesn't

lucidrains commented 5 months ago

@Damjan-Kalajdzievski decided to just power through this before my next wave of interviews. should all work now

Damjan-Kalajdzievski commented 5 months ago

Awesome, thanks for the fixes! I tested the CUDA striped and non-striped versions on main just now and they work for me too. Sorry I hadn't gotten to it before you got back to it; I had to prioritize other work for some weeks.

lucidrains commented 5 months ago

@Damjan-Kalajdzievski nice, good to hear! yes, it is understandable.. no worries