lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM
MIT License
7.71k stars 669 forks

I use other params with PaLM, but got an error #45

Closed wac81 closed 1 year ago

wac81 commented 1 year ago

model = PaLM(
    num_tokens=256,  # 512 1024
    dim=2048,  # dim_head*heads
    depth=24,
    dim_head=256,  # always 256
    heads=8,
    flash_attn=True
).to(device)

wac81 commented 1 year ago

training: 0%| | 0/2000000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "train_qa_webtext2.py", line 164, in <module>
    accelerator.backward(loss / GRADIENT_ACCUMULATE_EVERY)
  File "/home/wac/anaconda3/envs/py38/lib/python3.8/site-packages/accelerate/accelerator.py", line 1683, in backward
    loss.backward(**kwargs)
  File "/home/wac/anaconda3/envs/py38/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/home/wac/anaconda3/envs/py38/lib/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
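
The error message suggests CUDA_LAUNCH_BLOCKING=1 so that the traceback points at the kernel that actually failed. A minimal sketch of one way to set it from inside the training script, assuming the variable has to be in place before CUDA is initialized (so it is set before torch is imported):

```python
import os

# Assumption: the variable is read when CUDA is initialized, so set it
# before importing torch (or anything else that may touch the GPU).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# import torch  # import torch only after the variable is set
# ...then run the failing training step again and read the new traceback
```

With kernel launches made synchronous, the reported stack frame should correspond to the operation that raised "invalid argument". Training runs slower in this mode, so it is only meant for debugging.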

conceptofmind commented 1 year ago

model = PaLM(
    num_tokens=256,  # 512 1024
    dim=2048,  # dim_head*heads
    depth=24,
    dim_head=256,  # always 256
    heads=8,
    flash_attn=True
).to(device)

Which type of GPU are you using? Are you using PyTorch 2.0? Flash Attention requires an A100. Also, I do not believe Flash Attention supports dim_head larger than 128.

FlashAttention currently supports:
1. Turing, Ampere, Ada, or Hopper GPUs (e.g., A100).
2. fp16 and bf16
3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ..., 128). Head dim > 64 backward requires A100 or H100.
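
The three constraints listed above can be checked before building the model. The following is a rough sketch, not part of PaLM-rlhf-pytorch or FlashAttention: the function name `flash_attn_issues` is hypothetical, and the mapping from architecture names to CUDA compute capabilities (Turing 7.5, Ampere 8.0/8.6/8.7, Ada 8.9, Hopper 9.0) is an assumption added for illustration. On a real machine the capability tuple can be read with `torch.cuda.get_device_capability()`.

```python
from typing import List, Tuple

# Assumed compute capabilities for the architectures named in the list above.
SUPPORTED_CAPABILITIES = {(7, 5), (8, 0), (8, 6), (8, 7), (8, 9), (9, 0)}
A100_OR_H100 = {(8, 0), (9, 0)}
SUPPORTED_DTYPES = {"float16", "bfloat16"}


def flash_attn_issues(
    dim_head: int,
    dtype: str,
    capability: Tuple[int, int],
    training: bool = True,
) -> List[str]:
    """Return the constraints from the list above that a config violates."""
    issues = []
    if capability not in SUPPORTED_CAPABILITIES:
        issues.append(f"GPU capability {capability} is not Turing/Ampere/Ada/Hopper")
    if dtype not in SUPPORTED_DTYPES:
        issues.append(f"dtype {dtype} is not fp16 or bf16")
    if dim_head % 8 != 0 or dim_head < 8 or dim_head > 128:
        issues.append(f"dim_head {dim_head} is not a multiple of 8 in [8, 128]")
    if training and dim_head > 64 and capability not in A100_OR_H100:
        issues.append(f"backward with dim_head {dim_head} > 64 needs A100 or H100")
    return issues


# The configuration from this thread: dim_head=256 on an A6000 (capability 8.6).
# The dtype is not stated in the thread; float16 is assumed here.
for issue in flash_attn_issues(dim_head=256, dtype="float16", capability=(8, 6)):
    print(issue)
```

An empty list means none of the listed constraints is violated; it does not guarantee that a particular PyTorch or FlashAttention build will pick the flash kernel.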

wac81 commented 1 year ago

I'm using an A6000. Can't Flash Attention (FA) be enabled on it? And how do I set up a dim_head larger than 256?

wac81 commented 1 year ago

It runs fine with dim_head = 256 when I comment out FA (flash_attn).