Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Speeding up exp with lookup tables? #1261

Open ethansmith2000 opened 14 hours ago

ethansmith2000 commented 14 hours ago

Hi, I saw this recent post about FA-3 by PyTorch and noticed this bit: [attached screenshot]

Something I had been curious about for a while is how many values computed by exp() in bfloat16 map to out-of-bounds numbers or to the same numbers. Namely, of the 65536 possible input values, only 2267 map to outputs that are not 0, 1, or inf, and of course these all lie in predictable segments below or above some threshold.

Knowing this, is it possible to speed up the computation with a simple lookup table? I understand memory is a precious resource, so this may backfire, but I was curious whether this makes any sense at all.
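The claim above can be checked by brute force: bfloat16 is just the top 16 bits of a float32, so all 2^16 input patterns can be enumerated in plain Python. The exact count depends on the rounding convention assumed (this sketch rounds float64 exp results to float32 and then to bfloat16 with round-to-nearest-even), and the same enumeration doubles as the proposed lookup table, keyed by the raw bit pattern:

```python
import math
import struct

def bf16_to_f32(bits):
    """Interpret a 16-bit pattern as bfloat16 (the top half of a float32)."""
    return struct.unpack('<f', struct.pack('<I', bits << 16))[0]

def f32_to_bf16(x):
    """Round a float to the nearest bfloat16 bit pattern (round-to-nearest-even)."""
    try:
        b = struct.unpack('<I', struct.pack('<f', x))[0]
    except OverflowError:              # magnitude exceeds float32 range
        return 0xFF80 if x < 0 else 0x7F80   # +/- inf in bfloat16
    bias = 0x7FFF + ((b >> 16) & 1)
    return ((b + bias) >> 16) & 0xFFFF

def bf16_exp(bits):
    """exp() of a bfloat16 input, rounded back to a bfloat16 bit pattern."""
    x = bf16_to_f32(bits)
    if math.isnan(x):
        return 0x7FC0                  # a quiet-NaN pattern
    try:
        y = math.exp(x)
    except OverflowError:              # exp overflows even float64
        y = math.inf
    return f32_to_bf16(y)

# The proposed lookup table: exp for every possible bfloat16 input.
EXP_LUT = [bf16_exp(b) for b in range(1 << 16)]

# Count the inputs whose exp is something other than 0, 1, or inf.
nontrivial = sum(
    1 for b in range(1 << 16)
    if not math.isnan(bf16_to_f32(b))
    and bf16_to_f32(EXP_LUT[b]) not in (0.0, 1.0, math.inf)
)
print(nontrivial, "of 65536 inputs have a nontrivial exp")
```

A kernel would index such a table directly with the input's bit pattern, trading an exp evaluation for one (shared-memory or global) load; the 128 KiB size for full bfloat16 coverage is exactly the memory-pressure concern raised above.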

Fp8 would let us slim this table down even more (though I know ops are sometimes cast to higher precision, so I'm not really sure).
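For fp8 the table really does collapse: there are only 256 bit patterns, so the whole exp table fits in 256 entries. A sketch, assuming the OCP e4m3 layout (1 sign, 4 exponent, 3 mantissa bits, bias 7, no infinities, one NaN encoding per sign):

```python
import math

def e4m3_to_float(bits):
    """Decode an 8-bit OCP e4m3 pattern (1 sign, 4 exponent, 3 mantissa, bias 7)."""
    sign = -1.0 if bits & 0x80 else 1.0
    exp = (bits >> 3) & 0xF
    man = bits & 0x7
    if exp == 0xF and man == 0x7:          # the only NaN encoding; e4m3 has no inf
        return math.nan
    if exp == 0:                           # subnormals
        return sign * man * 2.0 ** -9
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)

# The entire exp table for fp8 e4m3 inputs: just 256 entries.
EXP8_LUT = [math.exp(e4m3_to_float(b)) if not math.isnan(e4m3_to_float(b))
            else math.nan for b in range(256)]

print(len(EXP8_LUT))  # prints 256
```

At this size the table fits comfortably in shared memory (or could even be broadcast through registers), which is what makes the fp8 variant of the idea attractive.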

tridao commented 13 hours ago

That's a good idea! I haven't tried it but it could potentially speed things up. The challenge might be that multiple threads indexing into the lookup table can cause bank conflicts. It's not immediately clear if this would be faster or slower than calling the exponential function.

SonicCodes commented 9 hours ago
[benchmark screenshot attached]

Tried a bunch of optimizations here and there; there's a consistent 10% slowdown from using a 1024-entry shared-memory LUT instead of expf on fp32.

But I could see a way to utilize registers: it seems you can get a ~2x speedup over expf if you can fit the codebook in registers as well. Loading and unloading takes time, which means your threads have to do longer stretches of computation to realize the speedup, but it seems interesting enough :)