Closed: wac81 closed this issue 1 year ago.
```
training:   0%|          | 0/2000000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "train_qa_webtext2.py", line 164, in ...
...
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```

```python
model = PaLM(
    num_tokens=256,    # 512, 1024
    dim=2048,          # dim_head * heads
    depth=24,
    dim_head=256,      # always 256
    heads=8,
    flash_attn=True
).to(device)
```
Which type of GPU are you using? Are you using PyTorch 2.0? Flash Attention requires an A100. Also, I do not believe Flash Attention supports a dim_head larger than 128.
FlashAttention currently supports:
1. Turing, Ampere, Ada, or Hopper GPUs (e.g., A100).
2. fp16 and bf16
3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ..., 128). Head dim > 64 backward requires A100 or H100.
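As a rough self-check against the constraints quoted above, something like the sketch below shows why dim_head = 256 falls outside the supported range. The helper name `can_use_flash` and its thresholds are illustrative only; they are not part of FlashAttention or PyTorch.

```python
import torch

def can_use_flash(dim_head: int, dtype: torch.dtype = torch.float16) -> bool:
    """Illustrative check mirroring the quoted rules: Turing-or-newer GPU,
    fp16/bf16 dtype, head dimension a multiple of 8 and at most 128."""
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) < (7, 5):                       # Turing (sm_75) or newer
        return False
    if dtype not in (torch.float16, torch.bfloat16):  # fp16 / bf16 only
        return False
    if dim_head % 8 != 0 or dim_head > 128:           # head-dim rule
        return False
    return True

print(can_use_flash(dim_head=256))  # False: 256 exceeds the 128 head-dim limit
print(can_use_flash(dim_head=128))  # True on an A6000 (Ampere), forward pass only
```

Note that even with dim_head = 128, item 3 says the backward pass needs an A100 or H100, so for training on an A6000 the head dimension effectively has to be 64 or less.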
I'm using an A6000. Why can't I enable Flash Attention, and how can I set up a dim_head as large as 256?
It runs well with dim_head = 256 once I comment out Flash Attention:
```python
model = PaLM(
    num_tokens=256,    # 512, 1024
    dim=2048,          # dim_head * heads
    depth=24,
    dim_head=256,      # always 256
    heads=8,
    flash_attn=True
).to(device)
```
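For what it's worth, here is a sketch of two workarounds, written against the same constructor arguments as the snippet above (the `palm_rlhf_pytorch` import path is an assumption; use whatever `train_qa_webtext2.py` actually imports): either keep dim_head = 256 and leave flash attention off, or keep dim = dim_head * heads = 2048 but shrink dim_head to 64 so it fits the quoted limits, including the backward-pass rule for non-A100/H100 cards.

```python
import torch
from palm_rlhf_pytorch import PaLM  # assumed import path; match your script

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Option A: keep dim_head = 256 and simply disable flash attention.
# Standard attention has no 128 head-dim limit, at the cost of speed and memory.
model_a = PaLM(
    num_tokens=256,
    dim=2048,
    depth=24,
    dim_head=256,
    heads=8,
    flash_attn=False,
).to(device)

# Option B: keep dim = dim_head * heads = 2048 but use more, smaller heads.
# dim_head = 64 satisfies "multiple of 8, up to 128" and avoids the
# "head dim > 64 backward requires A100 or H100" restriction on an A6000.
model_b = PaLM(
    num_tokens=256,
    dim=2048,
    depth=24,
    dim_head=64,
    heads=32,
    flash_attn=True,
).to(device)
```

Option B changes the attention geometry (more heads, smaller per-head dimension), so it is not numerically equivalent to the original configuration, but it keeps the model width the same while letting flash attention run within the limits listed above.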