Closed Damjan-Kalajdzievski closed 5 months ago
ah nice find, yea i ran out of time before testing the ring cuda integration with attention
want to bump the patch version and i'll release it?
Actually, do you want me to get the code running using 8 A100's with the NCCL backend first before bumping the version? I can batch the PR's however you'd like. I haven't been able to get it running yet, and found a few other things, for example:
Another change I had to make was to replace the isend/irecv requests in `send_and_receive_` with batched ones:

    import torch.distributed as dist

    def send_and_receive_(x, receive_buffer, send_to_rank, receive_from_rank):
        # queue the send and the matching receive as one batch of P2P ops
        ops = [
            dist.P2POp(dist.isend, x, send_to_rank),
            dist.P2POp(dist.irecv, receive_buffer, receive_from_rank),
        ]
        reqs = dist.batch_isend_irecv(ops)
        # block until both transfers have completed
        for req in reqs:
            req.wait()
        dist.barrier()

Otherwise, the code would hang on the `dist.isend` request. Not sure if this is specific to my machine setup, but most codebases seem to prefer `dist.batch_isend_irecv`, so maybe it's less buggy.
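For readers following along without a multi-GPU node, here is a minimal pure-Python sketch of the ring exchange pattern that `send_and_receive_` is used for. It assumes each rank sends its current block to the next rank and receives from the previous one; no `torch.distributed` calls are involved, and the names `ring_pass` and `simulate_ring` are hypothetical, not part of the codebase.

```python
def ring_pass(blocks):
    """One ring step: rank r sends its block to r + 1 and receives from r - 1."""
    world_size = len(blocks)
    return [blocks[(rank - 1) % world_size] for rank in range(world_size)]


def simulate_ring(world_size):
    # held[r] is the block currently sitting on rank r
    held = [f"kv{r}" for r in range(world_size)]
    # seen[r] records every block rank r has had access to, in order
    seen = [[block] for block in held]
    # world_size - 1 passes are enough for every rank to see every block
    for _ in range(world_size - 1):
        held = ring_pass(held)
        for rank, block in enumerate(held):
            seen[rank].append(block)
    return seen


seen = simulate_ring(4)
print(seen[0])
```

In the real code every rank has to post its send and its receive for each step, which is why a send that never finds a matching receive shows up as a hang rather than an error.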
@Damjan-Kalajdzievski sure! i didn't use the batched version as for the autoregressive case i just concatted the key / values and sent them in one go, but if somehow it is more mature, let's go with that (once you tested it of course)
better yet, keep the logic for both and introduce a flag that allows for one or the other. i went through some github issue a few weeks ago and it felt like the batched version was less mature on quick perusal
@lucidrains Sorry for the delay, but I got stuck after running into some hard to debug issues. I have changed this PR to a draft which hopefully progresses to tackling these issues, which I describe in the edited first comment.
@Damjan-Kalajdzievski thanks for the time you've put into this
could you try 0.2.9 and see how that fares for both striped and unstriped?
Cool, did not see those changes, I rebased this branch on them.
For striped I get the same error during the ring attention backward in `flash_attn_cuda.bwd`:

    dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
    RuntimeError: dout must have shape (batch_size, seqlen_q, num_heads, head_size_og))
and for non-striped I get through the computation, but the allclose assert check fails for the output:

    assert torch.allclose(ring_out, flash_out, atol=1e-6), "output is not the same"

When I examine how close the two tensors are, I get

    a = ring_out - flash_out
    a.max()
    > tensor(0.6748, grad_fn=<MaxBackward1>)
    a.mean()
    > tensor(0.0011, grad_fn=<MeanBackward0>)

The backwards pass grad check does return that the grads are close though.

    torch.allclose(ring_embed_grad, flash_embed_grad, atol=1e-2)
    > True

With the max distance being 0.0004
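One caveat on the numbers above: `a.max()` and `a.mean()` are taken over the signed difference, so positive and negative errors can cancel in the mean. Below is a small plain-Python sketch of the elementwise test that `torch.allclose` applies, `|a - b| <= atol + rtol * |b|` with `rtol` defaulting to 1e-5, alongside absolute error statistics. The sample values are made up for illustration and the helper names are hypothetical.

```python
def allclose(a, b, rtol=1e-5, atol=1e-8):
    # same elementwise criterion as torch.allclose, on flat lists of floats
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


def abs_error_stats(a, b):
    # max and mean of the absolute difference, so errors cannot cancel out
    diffs = [abs(x - y) for x, y in zip(a, b)]
    return max(diffs), sum(diffs) / len(diffs)


# made-up values: two elements agree, two are off by about 0.67 in opposite directions
ring_out = [0.10, -0.50, 0.90, 0.30]
flash_out = [0.10, 0.17, 0.90, -0.37]

signed = [x - y for x, y in zip(ring_out, flash_out)]
signed_mean = sum(signed) / len(signed)
max_abs, mean_abs = abs_error_stats(ring_out, flash_out)

print(signed_mean, max_abs, mean_abs)
print(allclose(ring_out, flash_out, atol=1e-6))
```

On tensors, the equivalent check would be `a.abs().max()` and `a.abs().mean()`.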
@Damjan-Kalajdzievski ok, try yet again for striped? and yea, i'm sure it must be something small, seems close. you may need to help debug this, as i am in the midst of interviews etc
Great, that brought striped attention in line with what is happening for non-striped (ie the allclose of the forwards pass fails with about the same magnitude, but the gradients closeness check passes). It is merged in. Yes I should be able to debug the forwards pass, I will get to that after some other things I have to prioritize as well. Thanks!
Thank you Damjan 🙏
edit: so the strange thing is that the gradients should be unlikely to be the same without the output matching. maybe a clue
@Damjan-Kalajdzievski hey Damjan, thanks for the PR once again
finally found some off time to test this out myself, and I believe the non-striped causal version is working fine. it is the striped version that is broken (only for output and not gradient for some reason, like you noted), and i will probably get to the bottom of that by end of week. could you verify that non-striped passes on your end?
edit: ah, so the gradients are also failing for striped causal version of ring + flash cuda. the difference was just too small to be detected
going to close this PR, as the main branch now contains most of the fixes
@Damjan-Kalajdzievski found the offending issue with stripe ring attention. turns out at the moment it only works for 1 bucket per machine
added an assert for now, but will plan on supporting multiple buckets
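As background on what "striped" means here, a rough sketch that assumes striping is a round-robin assignment of token positions to ranks, as opposed to giving each rank one contiguous chunk. The repository's actual layout, including how buckets interact with striping, may differ; the function names are hypothetical and `num_tokens` is assumed to be divisible by `world_size`.

```python
def shard_contiguous(num_tokens, world_size):
    # rank r gets one contiguous chunk of the sequence
    per_rank = num_tokens // world_size
    return [list(range(r * per_rank, (r + 1) * per_rank)) for r in range(world_size)]


def shard_striped(num_tokens, world_size):
    # rank r gets every world_size-th token, starting at position r
    return [list(range(r, num_tokens, world_size)) for r in range(world_size)]


contiguous = shard_contiguous(8, 4)
striped = shard_striped(8, 4)
print(contiguous)
print(striped)
```

With a causal mask, the contiguous layout leaves the early ranks with far fewer unmasked key positions than the late ranks, whereas the striped layout spreads early and late positions across all ranks.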
@Damjan-Kalajdzievski all tests should pass now with multiple cuda devices over NCCL; please open an issue if it doesn't
@Damjan-Kalajdzievski decided to just power through this before my next wave of interviews. should all work now
Awesome, thanks for the fixes! I tested the cuda striped and non-striped on main just now and it works for me too. Sorry I had not gotten the time yet before you got back to it; I had to prioritize other work for some weeks.
@Damjan-Kalajdzievski nice, good to hear! yes, it is understandable.. no worries
`ring_flash_attn_cuda` was mistyped as `ring_flash_attention_cuda` in the import.

EDIT: Changing this PR to draft as further changes are necessary, and it is not yet working with CUDA
Description
This PR seeks to get the codebase running with CUDA. We use a minimally modified version of `assert.py`, `assert_cuda.py`, as a test file to check whether the codebase can be run on 8 GPUs. Note: you must specify the environment variable `CUDA_VISIBLE_DEVICES`, eg:

Environment info
running on a single node with 8x40GB A100s
```
absl-py==2.0.0
accelerate==0.23.0
aiohttp==3.9.1
aiosignal==1.3.1
annotated-types==0.6.0
anthropic==0.9.0
anyio==4.2.0
appdirs==1.4.4
asttokens==2.4.1
async-timeout==4.0.3
attrs==23.2.0
beartype==0.17.2
bitsandbytes==0.41.2.post2
cachetools==5.3.2
certifi==2023.11.17
chardet==5.2.0
charset-normalizer==3.3.2
chex==0.1.85
click==8.1.7
colorama==0.4.6
comm==0.2.1
contourpy==1.2.0
cycler==0.12.1
DataProperty==1.0.1
datasets==2.14.6
debugpy==1.8.0
decorator==5.1.1
deepspeed==0.12.2
dill==0.3.7
distro==1.9.0
docker-pycreds==0.4.0
docstring-parser==0.15
einops==0.7.0
einx==0.1.3
etils==1.5.2
evaluate==0.4.0
exceptiongroup==1.2.0
executing==2.0.1
fastapi==0.109.0
filelock==3.13.1
flash-attn==2.5.2
flax==0.7.0
fonttools==4.47.2
frozendict==2.4.0
frozenlist==1.4.1
fsspec==2023.10.0
gitdb==4.0.11
GitPython==3.1.41
google-auth==2.26.1
google-auth-oauthlib==1.2.0
grpcio==1.60.0
h11==0.14.0
hjson==3.1.0
httpcore==1.0.2
httpx==0.26.0
huggingface-hub==0.20.2
idna==3.6
importlib-metadata==7.0.1
importlib-resources==6.1.1
ipykernel==6.29.0
ipython==8.18.1
jax==0.4.25
jaxlib==0.4.25
jedi==0.19.1
Jinja2==3.1.2
joblib==1.3.2
jsonlines==4.0.0
jsonschema==4.20.0
jsonschema-specifications==2023.12.1
jupyter_client==8.6.0
jupyter_core==5.7.1
kiwisolver==1.4.5
lxml==5.1.0
Markdown==3.5.2
markdown-it-py==3.0.0
markdown2==2.4.12
MarkupSafe==2.1.3
matplotlib==3.8.2
matplotlib-inline==0.1.6
mbstrdecoder==1.1.3
mdurl==0.1.2
ml-dtypes==0.3.2
mpmath==1.3.0
msgpack==1.0.7
multidict==6.0.4
multiprocess==0.70.15
nest-asyncio==1.5.9
networkx==3.2.1
nh3==0.2.15
ninja==1.11.1.1
nltk==3.8.1
numexpr==2.8.8
numpy==1.26.3
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
openai==0.28.1
opt-einsum==3.3.0
optax==0.1.9
orbax-checkpoint==0.5.3
packaging==23.2
pandas==2.1.4
parso==0.8.3
pathvalidate==3.2.0
peft==0.7.1
pexpect==4.9.0
pillow==10.2.0
platformdirs==4.1.0
portalocker==2.8.2
prompt-toolkit==3.0.43
protobuf==3.20.2
psutil==5.9.7
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==9.0.0
pyarrow==14.0.2
pyasn1==0.5.1
pyasn1-modules==0.3.0
pybind11==2.11.1
pydantic==1.10.13
pydantic_core==2.14.6
Pygments==2.17.2
pynvml==11.5.0
pyparsing==3.1.1
pytablewriter==1.2.0
python-dateutil==2.8.2
pytz==2023.3.post1
PyYAML==6.0.1
pyzmq==25.1.2
ray==2.9.0
referencing==0.32.1
regex==2023.12.25
requests==2.31.0
requests-oauthlib==1.3.1
responses==0.18.0
rich==13.7.0
rouge-score==0.1.2
rpds-py==0.16.2
rsa==4.9
sacrebleu==2.4.0
safetensors==0.4.1
scikit-learn==1.3.2
scipy==1.11.4
seaborn==0.13.2
sentencepiece==0.1.99
sentry-sdk==1.39.2
setproctitle==1.3.3
shortuuid==1.0.11
shtab==1.6.5
six==1.16.0
smmap==5.0.1
sniffio==1.3.0
sqlitedict==2.1.0
stack-data==0.6.3
starlette==0.35.1
svgwrite==1.4.3
sympy==1.12
tabledata==1.3.3
tabulate==0.9.0
tcolorpy==0.1.4
tensorboard==2.15.1
tensorboard-data-server==0.7.2
tensorstore==0.1.53
threadpoolctl==3.2.0
tiktoken==0.5.2
tokenizers==0.15.0
toolz==0.12.1
torch==2.1.2
tornado==6.4
tqdm==4.66.1
tqdm-multiprocess==0.0.11
traitlets==5.14.1
transformers==4.36.2
triton==2.1.0
trl==0.7.7
typepy==1.3.2
typing_extensions==4.9.0
tyro==0.6.3
tzdata==2023.4
urllib3==2.1.0
uvicorn==0.25.0
wandb==0.16.2
wavedrom==2.0.3.post3
wcwidth==0.2.13
Werkzeug==3.0.1
xxhash==3.4.1
yarl==1.9.4
zipp==3.17.0
zstandard==0.22.0
```
State of changes before rebase on 0.2.9
After some minor changes to get things kicked off (eg missing import, typos, ...), the ring attention transformer forwards pass hangs in `send_and_receive_` on `dist.recv(receive_buffer, receive_from_rank)`, and we find that replacing the isend/irecv requests in `send_and_receive_` with batched ones fixes this (see the comment below).

With the batched request changed and turned on, the ring attention transformer fails in the flash attention backwards pass with:
When `striped_ring_attn=False` is set in `assert_cuda.py`, we find that the ring backwards pass returns from `_flash_attn_backward` successfully, but hangs at an undetermined later point.

State as of 5532b7b
both striped and non-striped ring attention get through the computation but are not close in terms of forward pass outputs.
However the gradients are close, with a max difference on the order of 0.0004