FoundationVision / VAR

[GPT beats diffusion🔥] [scaling laws in visual generation📈] Official impl. of "Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction". An *ultra-simple, user-friendly yet state-of-the-art* codebase for autoregressive image generation!
MIT License

Dtype error with flash-attention #30

Closed · ThisisBillhe closed 2 months ago

ThisisBillhe commented 2 months ago

This error occurs when sampling with a pretrained model:

```
  File "/xxx/VAR/models/basic_var.py", line 113, in forward
    oup = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=self.scale).view(B, L, C)
RuntimeError: FlashAttention only support fp16 and bf16 data type.
```

The problem is that although qkv is initially fp16, the scale_mul tensor at line 101 of basic_var.py is fp32, so multiplying by it promotes q and k to fp32.

Update: F.normalize(q, dim=-1) also promotes the dtype of q to fp32.
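
For reference, a minimal sketch of the kind of workaround described above (not necessarily the exact change in the repo): cast q and k back to the original half-precision dtype after the fp32 scale_mul multiplication and F.normalize, so flash_attn_func no longer rejects them. The helper name `qk_l2norm_keep_dtype` is hypothetical.

```python
import torch
import torch.nn.functional as F

def qk_l2norm_keep_dtype(q: torch.Tensor, k: torch.Tensor, scale_mul: torch.Tensor):
    """L2-normalize q/k and apply the learned scale while keeping the
    original fp16/bf16 dtype expected by flash_attn_func."""
    main_dtype = q.dtype  # e.g. torch.float16 when sampling with a half-precision model
    # The fp32 scale_mul parameter (and, per the report above, F.normalize)
    # can promote the result to fp32; cast back before FlashAttention.
    q = F.normalize(q, dim=-1).mul(scale_mul).to(main_dtype)
    k = F.normalize(k, dim=-1).to(main_dtype)
    return q, k
```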

keyu-tian commented 2 months ago

Thanks @ThisisBillhe. I've now fixed this bug in 74138aa.