Closed: jpfeil closed this issue 1 year ago
I also ran into this issue using A100 GPUs. My workaround was to bypass Flash Attention by commenting out the following lines in `attend.py`:
```python
if device_properties.major == 8 and device_properties.minor == 0:
    print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
    self.cuda_config = Config(True, False, False)
```
Without those lines, it defaults to the "math or mem efficient attention" path, per the print statement on the lines that follow. Training works with them commented out!
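For what it's worth, the same effect should be achievable without editing the library, by steering PyTorch's SDP backend selection directly. This is just a sketch assuming the `torch.backends.cuda.sdp_kernel` context manager from PyTorch 2.0; the tensor shapes are made up for illustration:

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16)

# Disable the flash kernel so scaled_dot_product_attention falls back to the
# math / memory-efficient implementations, mirroring what commenting out the
# A100 branch in attend.py accomplishes.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False,
    enable_math=True,
    enable_mem_efficient=True,
):
    out = F.scaled_dot_product_attention(q, k, v)
```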
I'm investigating further but figured I'd share this in case it's helpful for anyone in the meantime 🫡
@jacobpfeil @timlenardo yeah, i'm going to remove all the manual checks
researchers are telling me that pytorch 2.1 flash attention works much more seamlessly
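(For context, a minimal sketch of what dropping the manual checks could look like on PyTorch >= 2.0, where `F.scaled_dot_product_attention` dispatches to the best available backend on its own; the `attend` function and its signature here are illustrative, not the repo's actual code:)

```python
import torch.nn.functional as F

def attend(q, k, v, mask=None, dropout_p=0.0):
    # No device_properties checks or manual Config needed: PyTorch 2.x picks
    # flash / memory-efficient / math attention automatically based on the
    # inputs and the hardware.
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=dropout_p
    )
```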
I'm trying to train the model on ImageNet, but I'm running into issues fitting the model and data in GPU memory. I'm using A100 GPUs, but when the trainer runs I get this error:
I think this is related to this issue: https://github.com/lucidrains/x-transformers/issues/143
Is there a workaround for this issue?
Thank you!