Found the same problem when I tried to apply flash attention to stable diffusion. I hit the problem when using cuda:6, and it cost me about a whole day to debug 💔.
Sorry about that, I've just reproduced it as well. Must be because we're not setting the right device somewhere. Will try to figure it out.
I've just pushed a commit that fixed this. Let me know if it works on your side.
Hello, thank you for your attention. However, I tried the newest commit and found the problem is still there, even when I use the cuda:0 device. The former version of the code (commit ID: 8166063a556e17e03e4a0697ba604def1eeb6a99) works well with cuda:0.
Here is some information that may be helpful:
I use flash_attention to speed up diffusers, replacing diffusers/models/attention.py with the following code, as suggested in https://www.reddit.com/r/StableDiffusion/comments/xmr3ic/speed_up_stable_diffusion_by_50_using_flash/ and https://nn.labml.ai/diffusion/stable_diffusion/model/unet_attention.html#section-45:
"""
---
title: Transformer for Stable Diffusion U-Net
summary: >
Annotated PyTorch implementation/tutorial of the transformer
for U-Net in stable diffusion.
---
# Transformer for Stable Diffusion [U-Net](unet.html)
This implements the transformer module used in [U-Net](unet.html) that
gives $\epsilon_\text{cond}(x_t, c)$
We have kept the model definition and naming unchanged from
[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
so that we can load the checkpoints directly.
"""
from typing import Optional
import os
import math
import torch
import torch.nn.functional as F
from torch import nn
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
Uses three q, k, v linear layers to compute attention.
Parameters:
channels (:obj:`int`): The number of channels in the input and output.
num_head_channels (:obj:`int`, *optional*):
The number of channels in each head. If None, then `num_heads` = 1.
num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
"""
def __init__(
self,
channels: int,
num_head_channels: Optional[int] = None,
num_groups: int = 32,
rescale_output_factor: float = 1.0,
eps: float = 1e-5,
):
super().__init__()
self.channels = channels
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
self.num_head_size = num_head_channels
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
# define q,k,v as linear layers
self.query = nn.Linear(channels, channels)
self.key = nn.Linear(channels, channels)
self.value = nn.Linear(channels, channels)
self.rescale_output_factor = rescale_output_factor
        self.proj_attn = nn.Linear(channels, channels, bias=True)
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection
def forward(self, hidden_states):
residual = hidden_states
batch, channel, height, width = hidden_states.shape
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
# transpose
query_states = self.transpose_for_scores(query_proj)
key_states = self.transpose_for_scores(key_proj)
value_states = self.transpose_for_scores(value_proj)
# get scores
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
# compute attention output
hidden_states = torch.matmul(attention_probs, value_states)
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
hidden_states = hidden_states.view(new_hidden_states_shape)
# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
class SpatialTransformer(nn.Module):
"""
## Spatial Transformer
"""
def __init__(self, channels: int, n_heads: int, d_cond: int, depth: int, num_groups=32, context_dim=None):
"""
:param channels: is the number of channels in the feature map
:param n_heads: is the number of attention heads
:param n_layers: is the number of transformer layers
:param d_cond: is the size of the conditional embedding
"""
super().__init__()
        # Map the diffusers-style argument names onto the names used below
        n_layers = depth
        d_cond = context_dim
# Initial group normalization
        self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6, affine=True)
# Initial $1 \times 1$ convolution
self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
# Transformer layers
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(channels, n_heads, channels // n_heads, d_cond=d_cond) for _ in range(n_layers)]
)
# Final $1 \times 1$ convolution
self.proj_out = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
def forward(self, hidden_states: torch.Tensor, context: torch.Tensor):
"""
        :param hidden_states: is the feature map of shape `[batch_size, channels, height, width]`
        :param context: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
"""
# Get shape `[batch_size, channels, height, width]`
x = hidden_states
cond = context
b, c, h, w = x.shape
# For residual connection
x_in = x
# Normalize
x = self.norm(x)
# Initial $1 \times 1$ convolution
x = self.proj_in(x)
# Transpose and reshape from `[batch_size, channels, height, width]`
# to `[batch_size, height * width, channels]`
x = x.permute(0, 2, 3, 1).view(b, h * w, c)
# Apply the transformer layers
for block in self.transformer_blocks:
x = block(x, cond)
# Reshape and transpose from `[batch_size, height * width, channels]`
# to `[batch_size, channels, height, width]`
x = x.view(b, h, w, c).permute(0, 3, 1, 2)
# Final $1 \times 1$ convolution
x = self.proj_out(x)
# Add residual
return x + x_in
class BasicTransformerBlock(nn.Module):
"""
### Transformer Layer
"""
def __init__(self, d_model: int, n_heads: int, d_head: int, d_cond: int):
"""
:param d_model: is the input embedding size
:param n_heads: is the number of attention heads
        :param d_head: is the size of an attention head
:param d_cond: is the size of the conditional embeddings
"""
super().__init__()
# Self-attention layer and pre-norm layer
self.attn1 = CrossAttention(d_model, d_model, n_heads, d_head)
self.norm1 = nn.LayerNorm(d_model)
# Cross attention layer and pre-norm layer
self.attn2 = CrossAttention(d_model, d_cond, n_heads, d_head)
self.norm2 = nn.LayerNorm(d_model)
# Feed-forward network and pre-norm layer
self.ff = FeedForward(d_model)
self.norm3 = nn.LayerNorm(d_model)
def forward(self, x: torch.Tensor, cond: torch.Tensor):
"""
:param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
:param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
"""
# Self attention
x = self.attn1(self.norm1(x)) + x
# Cross-attention with conditioning
x = self.attn2(self.norm2(x), cond=cond) + x
# Feed-forward network
x = self.ff(self.norm3(x)) + x
#
return x
class CrossAttention(nn.Module):
"""
### Cross Attention Layer
    This falls back to self-attention when conditional embeddings are not specified.
"""
    use_flash_attention: bool = int(os.environ.get("USE_FLASH_ATTENTION", 0)) == 1
    # NOTE: the line below forces flash attention on, overriding the environment flag
    use_flash_attention = True
def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
"""
:param d_model: is the input embedding size
:param n_heads: is the number of attention heads
        :param d_head: is the size of an attention head
:param d_cond: is the size of the conditional embeddings
:param is_inplace: specifies whether to perform the attention softmax computation inplace to
save memory
"""
super().__init__()
self.is_inplace = is_inplace
self.n_heads = n_heads
self.d_head = d_head
# Attention scaling factor
self.scale = d_head ** -0.5
# Query, key and value mappings
d_attn = d_head * n_heads
self.to_q = nn.Linear(d_model, d_attn, bias=False)
self.to_k = nn.Linear(d_cond, d_attn, bias=False)
self.to_v = nn.Linear(d_cond, d_attn, bias=False)
# Final linear layer
self.to_out = nn.Sequential(nn.Linear(d_attn, d_model))
# Setup [flash attention](https://github.com/HazyResearch/flash-attention).
# Flash attention is only used if it's installed
# and `CrossAttention.use_flash_attention` is set to `True`.
try:
# You can install flash attention by cloning their Github repo,
# [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
# and then running `python setup.py install`
from flash_attn.flash_attention import FlashAttention
self.flash = FlashAttention()
# Set the scale for scaled dot-product attention.
self.flash.softmax_scale = self.scale
# Set to `None` if it's not installed
except ImportError:
self.flash = None
def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
"""
:param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
:param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
"""
# If `cond` is `None` we perform self attention
has_cond = cond is not None
if not has_cond:
cond = x
# Get query, key and value vectors
q = self.to_q(x)
k = self.to_k(cond)
v = self.to_v(cond)
# Use flash attention if it's available and the head size is less than or equal to `128`
if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128:
return self.flash_attention(q, k, v)
# Otherwise, fallback to normal attention
else:
return self.normal_attention(q, k, v)
def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""
#### Flash Attention
        :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
"""
# Get batch size and number of elements along sequence axis (`width * height`)
batch_size, seq_len, _ = q.shape
# Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
# shape `[batch_size, seq_len, 3, n_heads * d_head]`
qkv = torch.stack((q, k, v), dim=2)
# Split the heads
qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
# Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
# fit this size.
if self.d_head <= 32:
pad = 32 - self.d_head
elif self.d_head <= 64:
pad = 64 - self.d_head
elif self.d_head <= 128:
pad = 128 - self.d_head
else:
            raise ValueError(f'Head size {self.d_head} too large for Flash Attention')
# Pad the heads
if pad:
qkv = torch.cat((qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1)
# Compute attention
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
# This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
out, _ = self.flash(qkv)
# Truncate the extra head size
out = out[:, :, :, :self.d_head]
# Reshape to `[batch_size, seq_len, n_heads * d_head]`
out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
# Map to `[batch_size, height * width, d_model]` with a linear layer
return self.to_out(out)
def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""
#### Normal Attention
        :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
"""
# Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
q = q.view(*q.shape[:2], self.n_heads, -1)
k = k.view(*k.shape[:2], self.n_heads, -1)
v = v.view(*v.shape[:2], self.n_heads, -1)
# Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale
# Compute softmax
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
if self.is_inplace:
half = attn.shape[0] // 2
attn[half:] = attn[half:].softmax(dim=-1)
attn[:half] = attn[:half].softmax(dim=-1)
else:
attn = attn.softmax(dim=-1)
# Compute attention output
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
out = torch.einsum('bhij,bjhd->bihd', attn, v)
# Reshape to `[batch_size, height * width, n_heads * d_head]`
out = out.reshape(*out.shape[:2], -1)
# Map to `[batch_size, height * width, d_model]` with a linear layer
return self.to_out(out)
class FeedForward(nn.Module):
"""
### Feed-Forward Network
"""
def __init__(self, d_model: int, d_mult: int = 4):
"""
:param d_model: is the input embedding size
        :param d_mult: is the multiplicative factor for the hidden layer size
"""
super().__init__()
self.net = nn.Sequential(
GeGLU(d_model, d_model * d_mult),
nn.Dropout(0.),
nn.Linear(d_model * d_mult, d_model)
)
def forward(self, x: torch.Tensor):
return self.net(x)
class GeGLU(nn.Module):
"""
### GeGLU Activation
$$\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$$
"""
def __init__(self, d_in: int, d_out: int):
super().__init__()
# Combined linear projections $xW + b$ and $xV + c$
self.proj = nn.Linear(d_in, d_out * 2)
def forward(self, x: torch.Tensor):
# Get $xW + b$ and $xV + c$
x, gate = self.proj(x).chunk(2, dim=-1)
# $\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$
return x * F.gelu(gate)
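For context, here is a minimal usage sketch of how such a patched module would be exercised; the model ID, prompt, and device below are illustrative assumptions, not taken from this thread. USE_FLASH_ATTENTION must be set before the patched module is imported, since the class attribute is read at definition time.
# Minimal usage sketch (illustrative assumptions: model ID, prompt, device).
import os
os.environ["USE_FLASH_ATTENTION"] = "1"  # must be set before the import below

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda:6")  # the non-default device where the failure was observed
image = pipe("a photograph of an astronaut riding a horse").images[0]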
Let's try to figure this out.
- Can you try recompiling FlashAttention? pip uninstall -y flash_attn && rm -rf build && python setup.py install
- Another thing to try is to put torch.cuda.set_device(qkv.device) before calling self.flash.
Lmk if any of those help.
@geekinglcq are you on the cutlass branch of FlashAttention? I've also just ported the (same) fix to that branch.
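In case it's useful, a minimal sketch of where that workaround would sit inside CrossAttention.flash_attention above (note that device is a tensor attribute, not a method):
# Suggested workaround, placed just before the kernel call in
# CrossAttention.flash_attention: pin the current CUDA device to the
# device the inputs live on.
torch.cuda.set_device(qkv.device)
out, _ = self.flash(qkv)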
It works now! Thank you~
Great!
@geekinglcq Did recompilation fix it, or do you need to set torch.cuda.set_device(qkv.device)?
I just recompiled it.
Testing code:
If I add the following code and run it, the output is
If I run this instead, the output is
I've verified it's not a hardware issue by adding
at the top, and it works with cuda:0.
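As a rough, hypothetical sketch of this kind of check (not the elided snippets above; it assumes flash_attn is installed), the kernel can be called directly on a non-default device:
# Hypothetical standalone check, not the elided snippet from the thread:
# run FlashAttention directly on a non-default CUDA device. qkv has shape
# [batch, seq, 3, heads, head_dim] and must be fp16 (or bf16).
import torch
from flash_attn.flash_attention import FlashAttention

device = torch.device("cuda:6")
flash = FlashAttention()
qkv = torch.randn(2, 64, 3, 8, 64, dtype=torch.float16, device=device)
out, _ = flash(qkv)
print(out.shape, out.device)  # expected: torch.Size([2, 64, 8, 64]) cuda:6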