I also was not sure whether you want the attention mechanisms separated into two different functions or if just an if/else
statement would be sufficient.
@conceptofmind oh thanks for doing this! so actually my plan had been to just wait for pytorch 2.0 official release
flash attention will be available as a part of the functional package, probably the same code that was used to train llama
hopefully it is soon..
> @conceptofmind oh thanks for doing this! so actually my plan had been to just wait for pytorch 2.0 official release
> flash attention will be available as a part of the functional package, probably the same code that was used to train llama
> hopefully it is soon..
Ok. Awesome to know that it will be directly integrated with PyTorch now! Do you want me to close this PR or leave it open as a reference for others until they update the new 2.0 official release? Or I can update it when the new version is released. I am checking out the nightly now.
Side note, I am going to work on pre-training some small PaLM models on CC4, or something similar, that are compatible with this repo.
@conceptofmind you can leave it open, as pytorch 2.0 just got released this week with flash attention support! i'll probably try this weekend to upgrade to cuda 11.7 and to pytorch 2.0 and see if i can get it integrated
but if you want to give it a try, i would definitely welcome the contribution!
@lucidrains I updated to use the PyTorch 2.0 version of Flash Attention.
Following the documentation:
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    F.scaled_dot_product_attention(query, key, value)
Confirmed with Patrick from PyTorch: https://discuss.pytorch.org/t/flash-attention/174955
It trained for me with CUDA 11.7. I would just do a sanity check though since it is still so new.
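For the sanity check, something along these lines (shapes purely illustrative, not code from the PR) compares the fused kernel against a plain fp32 reference:

```python
import math
import torch
import torch.nn.functional as F

# illustrative tensors; the flash kernel wants fp16 / bf16 inputs on gpu
q, k, v = (torch.randn(1, 8, 128, 64, device = 'cuda', dtype = torch.float16) for _ in range(3))

# force the flash kernel only
with torch.backends.cuda.sdp_kernel(enable_flash = True, enable_math = False, enable_mem_efficient = False):
    fused = F.scaled_dot_product_attention(q, k, v)

# reference scaled dot product attention computed in fp32
scores = (q.float() @ k.float().transpose(-2, -1)) / math.sqrt(q.shape[-1])
reference = scores.softmax(dim = -1) @ v.float()

print((fused.float() - reference).abs().amax())  # should be on the order of fp16 rounding error
```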
Thank you,
Enrico
@conceptofmind wow, thank you Enrico! this looks great :pray:
do you want to add a citation for flash attention, and also do a version check to make sure pytorch is 2.0 before flash attention can be enabled?
@conceptofmind haha, i'm upgrading to pytorch 2.0 this morning, wish me luck (need to upgrade my CUDA from 11.4 too)
@conceptofmind one more thing to test is whether their context can naturally fallback to cpu, if cuda is not found
@lucidrains
> @conceptofmind wow, thank you Enrico! this looks great :pray:
> do you want to add a citation for flash attention, and also do a version check to make sure pytorch is 2.0 before flash attention can be enabled?
I will add the citation for Flash Attention. Is this what you imagined for an assertion:
if self.flash_attn:
    try:
        assert torch.__version__ >= "2.0.0"
    except:
        raise Exception("flash attention requires pytorch 2.0")
> @conceptofmind haha, i'm upgrading to pytorch 2.0 this morning, wish me luck (need to upgrade my CUDA from 11.4 too)
I upgraded from 11.3 to 11.7. Always fun installing new CUDA versions and having to scan through 10s of files to make sure there are no residual conflicting artifacts from previous installations :laughing:
> @conceptofmind one more thing to test is whether their context can naturally fallback to cpu, if cuda is not found
I believe the way the context manager is set up is that it checks what is available from the three and defaults to the best option. I had hardcoded it to use FlashAttention to ensure that it was used. I can leave it at the default as well:
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
    F.scaled_dot_product_attention(query, key, value)
Or I can add some sort of filter if an A100 is available:
if enable_flash:
    enable_math = False
    enable_mem_efficient = False
else:
    enable_math = True
    enable_mem_efficient = True
One thing to note is that PyTorch 2.0 Flash Attention seems to throw a kernel error if a compatible GPU architecture is not available.
RuntimeError: No available kernel. Aborting execution.
Should we include something that checks for an A100:
try:
    if torch.cuda.is_available():
        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
        if device_properties.major == 8 and device_properties.minor == 0:
            enable_flash = True
        else:
            enable_flash = False
except RuntimeError as error:
    print(f'An error occurred: {error}.')
Let me know what you think is best!
Thank you,
Enrico
yup exactly! but you would do `from packaging import version`, and then `version.parse(torch.__version__) >= version.parse('2.0.0')`, as string comparison would have a different behavior.
and yes, try catches would be perfect for if the kernel selection heuristic fails
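i.e. roughly this (just a sketch, the helper name is made up):

```python
from packaging import version
import torch

def torch_supports_flash():
    # compare parsed versions; plain string comparison mis-orders e.g. '1.13.0' before '1.9.0'
    return version.parse(torch.__version__) >= version.parse('2.0.0')

assert torch_supports_flash(), 'flash attention requires pytorch 2.0'
```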
@lucidrains
> yup exactly! but you would do `from packaging import version`, and then `version.parse(torch.__version__) >= version.parse('2.0.0')`, as string comparison would have a different behavior.
> and yes, try catches would be perfect for if the kernel selection heuristic fails
Well I think the kernel selection works. :laughing:
# Check if there is a compatible device for flash attention
try:
    flash_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if flash_device.type == 'cuda':
        device_properties = torch.cuda.get_device_properties(flash_device)
        if device_properties.major == 8 and device_properties.minor == 0:
            print('A100 GPU detected, using flash attention')
            enable_flash = True
            enable_math = False
            enable_mem_efficient = False
        else:
            print('Non-A100 GPU detected, using math or mem efficient attention')
            enable_flash = False
            enable_math = True
            enable_mem_efficient = True
    else:
        # Default context manager settings with CPU
        print('CPU detected, using default context manager settings')
        enable_flash = True
        enable_math = True
        enable_mem_efficient = True
except RuntimeError as error:
    print(f'An error occurred: {error}.')
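Those flags then just wrap the attention call itself, roughly like this (simplified sketch reusing `flash_device` and the flags from above, not the exact PR code):

```python
import torch
import torch.nn.functional as F

# illustrative tensors; the flash kernel only takes fp16 / bf16 inputs on gpu
dtype = torch.float16 if flash_device.type == 'cuda' else torch.float32
q, k, v = (torch.randn(1, 8, 1024, 64, device = flash_device, dtype = dtype) for _ in range(3))

# the flags chosen above decide which kernels the context manager allows
with torch.backends.cuda.sdp_kernel(
    enable_flash = enable_flash,
    enable_math = enable_math,
    enable_mem_efficient = enable_mem_efficient
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
```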
Pushed the other updates as well.
Hopefully, everything looks better now.
Thank you,
Enrico
@conceptofmind yes, almost there! left two comments
Updated to include your recommendation for masking and additional logic.
I tried training and ran three forward/backward passes. The loss went down.
No Flash
import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    flash_attn = False,
).cuda()

seq = torch.randint(0, 20000, (1, 2048)).cuda()

loss = palm(seq, return_loss = True)
loss.backward()
print(loss)
Flash
import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    flash_attn = True,
).cuda()

seq = torch.randint(0, 20000, (1, 2048)).cuda()

loss = palm(seq, return_loss = True)
loss.backward()
print(loss)
CPU
import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
)

seq = torch.randint(0, 20000, (1, 2048))

loss = palm(seq, return_loss = True)
loss.backward()
print(loss)
Nice! Want to also test it with non-causal with key padding mask?
> Nice! Want to also test it with non-causal with key padding mask?
Will start writing some tests for non-causal now.
I tried with these configurations for `causal = False`:
No Flash
import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    causal = False,
    flash_attn = False,
)

seq = torch.randint(0, 20000, (1, 2048))

loss = palm(seq, return_loss = True)
loss.backward()
print(loss)
Flash
import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    causal = False,
    flash_attn = True,
)

seq = torch.randint(0, 20000, (1, 2048))

loss = palm(seq, return_loss = True)
loss.backward()
print(loss)
Flash
import torch
from palm_rlhf_pytorch import PaLM, RewardModel

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    causal = False,
    flash_attn = True,
)

reward_model = RewardModel(
    palm,
    num_binned_output = 5 # say rating from 1 to 5
).cuda()

# mock data

seq = torch.randint(0, 20000, (1, 1024)).cuda()
prompt_mask = torch.zeros(1, 1024).bool().cuda() # which part of the sequence is prompt, which part is response
labels = torch.randint(0, 5, (1,)).cuda()

# train

loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
loss.backward()
print(loss)
No Flash
import torch
from palm_rlhf_pytorch import PaLM, RewardModel

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    causal = False,
    flash_attn = False,
)

reward_model = RewardModel(
    palm,
    num_binned_output = 5 # say rating from 1 to 5
).cuda()

# mock data

seq = torch.randint(0, 20000, (1, 1024)).cuda()
prompt_mask = torch.zeros(1, 1024).bool().cuda() # which part of the sequence is prompt, which part is response
labels = torch.randint(0, 5, (1,)).cuda()

# train

loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
loss.backward()
print(loss)
I used the PyTorch version for expanding the mask you provided to keep it consistent.
Hopefully not missing anything.
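For reference, the expansion amounts to roughly this (shapes and names are illustrative, not the exact helper in the PR):

```python
import torch
import torch.nn.functional as F

batch, heads, seq, dim_head = 1, 8, 1024, 64
q, k, v = (torch.randn(batch, heads, seq, dim_head) for _ in range(3))

# boolean key padding mask, True = token may be attended to
key_padding_mask = torch.ones(batch, seq).bool()

# broadcast (batch, key_len) -> (batch, 1, 1, key_len) so it applies to every head and query position
attn_mask = key_padding_mask[:, None, None, :]

out = F.scaled_dot_product_attention(q, k, v, attn_mask = attn_mask)
```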
Thank you,
Enrico
@conceptofmind graciously accept! thank you for all your work on this PR 🙏
Hi Phil,
This PR adds support for Flash Attention in Triton. Tri Dao had given me some input on how to properly handle the tensor shapes for multi-query single-key-value attention with Flash Attention.
I was unsure which version of Flash Attention you wanted to add. I can work on changing it to the CUDA version if that is preferred.
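The gist of the shape handling (illustrative sketch, not the exact PR code): queries keep their heads, while the single shared key/value head is expanded across the query heads before calling the kernel.

```python
import torch

batch, heads, seq, dim_head = 1, 8, 1024, 64

q = torch.randn(batch, heads, seq, dim_head)  # multi-query: a full set of query heads
k = torch.randn(batch, 1, seq, dim_head)      # single shared key head
v = torch.randn(batch, 1, seq, dim_head)      # single shared value head

# expand the shared key / value head across the query heads (a broadcast view, no copy)
k, v = (t.expand(-1, heads, -1, -1) for t in (k, v))
```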
Thank you,
Enrico