lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM
MIT License
7.7k stars 666 forks

Add PyTorch 2.0 Flash Attention #40

Closed · conceptofmind closed 1 year ago

conceptofmind commented 1 year ago

Hi Phil,

This PR adds support for Flash Attention in Triton. Tri Dao had given me some input on how to properly handle the tensor shapes for multi-query single-key-value attention with Flash Attention.
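
For context, one common way to handle those shapes (a sketch only, not necessarily what this PR does) is to broadcast the single key/value head across the query heads before the fused attention call, shown here with made-up shapes and PyTorch 2.0's scaled_dot_product_attention for illustration:

import torch
import torch.nn.functional as F

# hypothetical multi-query shapes: queries have h heads, keys/values share a single head
b, h, n, d = 2, 8, 1024, 64
q = torch.randn(b, h, n, d)
k = torch.randn(b, 1, n, d)  # single key head
v = torch.randn(b, 1, n, d)  # single value head

# broadcast the shared key/value head across all query heads
k, v = map(lambda t: t.expand(-1, h, -1, -1), (k, v))

out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
print(out.shape)  # (b, h, n, d)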

I was unsure which version of Flash Attention you wanted to add. I can work on changing it to the CUDA version if that is preferred.

import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    flash_attn = True,
).half().cuda()

seq = torch.randint(0, 20000, (1, 2048)).cuda()

loss = palm(seq, return_loss = True)
print(loss)

Thank you,

Enrico

conceptofmind commented 1 year ago

I was also not sure whether you wanted the attention mechanisms separated into two different functions or if a simple if/else statement would be sufficient.

lucidrains commented 1 year ago

@conceptofmind oh thanks for doing this! so actually my plan had been to just wait for pytorch 2.0 official release

flash attention will be available as a part of the functional package, probably the same code that was used to train llama

hopefully it is soon..
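
For reference, in 2.0 the fused kernel is exposed through torch.nn.functional as scaled_dot_product_attention; a minimal sketch with arbitrary shapes:

import torch
import torch.nn.functional as F

# toy query/key/value tensors of shape (batch, heads, seq_len, head_dim)
q, k, v = (torch.randn(1, 8, 1024, 64) for _ in range(3))

# pytorch dispatches to a fused (flash) kernel when the device, dtype, and
# shapes allow it, and otherwise falls back to the math implementation
out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
print(out.shape)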

conceptofmind commented 1 year ago

> @conceptofmind oh thanks for doing this! so actually my plan had been to just wait for pytorch 2.0 official release
>
> flash attention will be available as a part of the functional package, probably the same code that was used to train llama
>
> hopefully it is soon..

Ok. Awesome to know that it will be directly integrated into PyTorch now! Do you want me to close this PR or leave it open as a reference for others until they update to the official 2.0 release? Or I can update it once the new version is released. I am checking out the nightly now.

Side note, I am going to work on pre-training some small PaLM models on CC4, or something similar, that are compatible with this repo.

lucidrains commented 1 year ago

@conceptofmind you can leave it open, as pytorch 2.0 just got released this week with flash attention support! i'll probably try this weekend to upgrade to cuda 11.7 and to pytorch 2.0 and see if i can get it integrated

but if you want to give it a try, i would definitely welcome the contribution!

conceptofmind commented 1 year ago

@lucidrains I updated to use the PyTorch 2.0 version of Flash Attention.

Following the documentation:

Confirmed with Patrick from PyTorch: https://discuss.pytorch.org/t/flash-attention/174955

It trained for me with CUDA 11.7. I would just do a sanity check though since it is still so new.

Thank you,

Enrico

lucidrains commented 1 year ago

@conceptofmind wow, thank you Enrico! this looks great :pray:

do you want to add a citation for flash attention, and also do a version check to make sure pytorch is 2.0 before flash attention can be enabled?

lucidrains commented 1 year ago

@conceptofmind haha, i'm upgrading to pytorch 2.0 this morning, wish me luck (need to upgrade my CUDA from 11.4 too)

lucidrains commented 1 year ago

@conceptofmind one more thing to test is whether their context can naturally fallback to cpu, if cuda is not found

conceptofmind commented 1 year ago

@lucidrains

> @conceptofmind wow, thank you Enrico! this looks great :pray:
>
> do you want to add a citation for flash attention, and also do a version check to make sure pytorch is 2.0 before flash attention can be enabled?

I will add the citation for Flash Attention. Is this what you imagined for an assertion:

if self.flash_attn:
  try:
    assert torch.__version__ >= "2.0.0"
  except:
    raise Exception("flash attention requires pytorch 2.0")

> @conceptofmind haha, i'm upgrading to pytorch 2.0 this morning, wish me luck (need to upgrade my CUDA from 11.4 too)

I upgraded from 11.3 to 11.7. Always fun installing new CUDA versions and having to scan through 10s of files to make sure there are no residual conflicting artifacts from previous installations :laughing:

> @conceptofmind one more thing to test is whether their context can naturally fallback to cpu, if cuda is not found

I believe the way the context manager is set up is that it checks what is available from the three backends and defaults to the best option. I had hardcoded it to use Flash Attention to ensure that it was used. I can leave it at the default as well:

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
    F.scaled_dot_product_attention(query, key, value)

Or I can add some sort of filter if an A100 is available:

if enable_flash:
    enable_math = False
    enable_mem_efficient = False
else:
    enable_math = True
    enable_mem_efficient = True

One thing to note is that PyTorch 2.0 Flash Attention seems to throw a kernel error if a compatible GPU architecture is not available.

RuntimeError: No available kernel.  Aborting execution.

Should we include something that checks for an A100:

try:
    if torch.cuda.is_available():
        device_properties = torch.cuda.get_device_properties(torch.cuda.current_device())
        # compute capability 8.0 corresponds to the A100
        if device_properties.major == 8 and device_properties.minor == 0:
            enable_flash = True
        else:
            enable_flash = False
except RuntimeError as error:
    print(f'An error occurred: {error}.')

Let me know what you think is best!

Thank you,

Enrico

lucidrains commented 1 year ago

yup exactly! but you would do from packaging import version, and then version.parse(torch.__version__) >= version.parse('2.0.0'), as string comparison would have a different behavior.

and yes, try catches would be perfect for if the kernel selection heuristic fails
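
For reference, a check along those lines might look like the following (a sketch; the helper name is made up):

from packaging import version
import torch

def assert_torch_supports_flash_attn():
    # parse the versions instead of comparing strings, since e.g. '10.0.0'
    # sorts before '2.0.0' under plain string comparison
    assert version.parse(torch.__version__) >= version.parse('2.0.0'), \
        'flash attention requires pytorch 2.0 or above'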

conceptofmind commented 1 year ago

@lucidrains

> yup exactly! but you would do from packaging import version, and then version.parse(torch.__version__) >= version.parse('2.0.0'), as string comparison would have a different behavior.
>
> and yes, try catches would be perfect for if the kernel selection heuristic fails

Well I think the kernel selection works. :laughing:

# Check if there is a compatible device for flash attention

try:
    flash_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if flash_device.type == 'cuda':
        device_properties = torch.cuda.get_device_properties(flash_device)
        # compute capability 8.0 corresponds to the A100
        if device_properties.major == 8 and device_properties.minor == 0:
            print('A100 GPU detected, using flash attention')
            enable_flash = True
            enable_math = False
            enable_mem_efficient = False
        else:
            print('Non-A100 GPU detected, using math or mem efficient attention')
            enable_flash = False
            enable_math = True
            enable_mem_efficient = True
    else:
        # Default context manager settings with CPU
        print('CPU detected, using default context manager settings')
        enable_flash = True
        enable_math = True
        enable_mem_efficient = True
except RuntimeError as error:
    print(f'An error occurred: {error}.')
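
Those flags would then be handed to the torch.backends.cuda.sdp_kernel context manager shown earlier; a sketch assuming a CUDA device and the flags computed above:

import torch
import torch.nn.functional as F

# half precision tensors on the gpu so the flash kernel is eligible
q, k, v = (torch.randn(1, 8, 1024, 64, device = 'cuda', dtype = torch.float16) for _ in range(3))

# restrict kernel selection to whatever the device check enabled
with torch.backends.cuda.sdp_kernel(enable_flash = enable_flash, enable_math = enable_math, enable_mem_efficient = enable_mem_efficient):
    out = F.scaled_dot_product_attention(q, k, v, is_causal = True)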

Pushed the other updates as well.

Hopefully, everything looks better now.

Thank you,

Enrico

lucidrains commented 1 year ago

@conceptofmind yes, almost there! left two comments

conceptofmind commented 1 year ago

Updated to include your recommendation for masking and additional logic.

I tried training and ran three forward/backward passes. The loss went down.

No Flash

import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    flash_attn = False,
).cuda()

seq = torch.randint(0, 20000, (1, 2048)).cuda()

loss = palm(seq, return_loss = True)
loss.backward()
print(loss)

Flash

import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    flash_attn = True,
).cuda()

seq = torch.randint(0, 20000, (1, 2048)).cuda()

loss = palm(seq, return_loss = True)
loss.backward()
print(loss)

CPU

import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
)

seq = torch.randint(0, 20000, (1, 2048))

loss = palm(seq, return_loss = True)
loss.backward()
print(loss)

lucidrains commented 1 year ago

Nice! Want to also test it with non-causal with key padding mask?

conceptofmind commented 1 year ago

> Nice! Want to also test it with non-causal with key padding mask?

Will start writing some tests for non-causal now.

conceptofmind commented 1 year ago

I tried the following configurations with causal = False:

No Flash

import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    causal = False,
    flash_attn = False,
)

seq = torch.randint(0, 20000, (1, 2048))

loss = palm(seq, return_loss = True)
loss.backward()
print(loss)

Flash

import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    causal = False,
    flash_attn = True,
)

seq = torch.randint(0, 20000, (1, 2048))

loss = palm(seq, return_loss = True)
loss.backward()
print(loss)

Flash (with RewardModel)

import torch
from palm_rlhf_pytorch import PaLM, RewardModel

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    causal = False,
    flash_attn = True,
)

reward_model = RewardModel(
    palm,
    num_binned_output = 5 # say rating from 1 to 5
).cuda()

# mock data

seq = torch.randint(0, 20000, (1, 1024)).cuda()
prompt_mask = torch.zeros(1, 1024).bool().cuda() # which part of the sequence is prompt, which part is response
labels = torch.randint(0, 5, (1,)).cuda()

# train

loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
loss.backward()
print(loss)

No Flash (with RewardModel)

import torch
from palm_rlhf_pytorch import PaLM, RewardModel

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    causal = False,
    flash_attn = False,
)

reward_model = RewardModel(
    palm,
    num_binned_output = 5 # say rating from 1 to 5
).cuda()

# mock data

seq = torch.randint(0, 20000, (1, 1024)).cuda()
prompt_mask = torch.zeros(1, 1024).bool().cuda() # which part of the sequence is prompt, which part is response
labels = torch.randint(0, 5, (1,)).cuda()

# train

loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
loss.backward()
print(loss)

I used the PyTorch version of the mask expansion you provided, to keep it consistent.
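
For reference, a minimal sketch of that kind of key padding mask expansion, with made-up shapes (the repo's own helper may differ):

import torch

# hypothetical key padding mask of shape (batch, seq_len); True = attend to this token
b, n = 1, 1024
key_padding_mask = torch.ones(b, n).bool()

# add singleton head and query-length dims so the mask broadcasts against the
# (batch, heads, query_len, key_len) scores inside scaled_dot_product_attention
attn_mask = key_padding_mask[:, None, None, :]  # -> (b, 1, 1, n)
print(attn_mask.shape)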

Hopefully not missing anything.

Thank you,

Enrico

lucidrains commented 1 year ago

@conceptofmind graciously accepted! thank you for all your work on this PR 🙏