Open ethansmith2000 opened 1 month ago
That's a good idea! I haven't tried it but it could potentially speed things up. The challenge might be that multiple threads indexing into the lookup table can cause bank conflicts. It's not immediately clear if this would be faster or slower than calling the exponential function.
Tried a bunch of optimizations here and there; there's a consistent ~10% slowdown using a shared-memory LUT (1024 entries) on fp32 compared to expf.
But I could see a way to utilize registers: it seems you can get roughly a 2x speedup over expf if you can fit the codebook in registers as well. Loading and unloading takes time, which means your threads have to do longer stretches of computation to realize the speedup, but it seems interesting enough :)
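For concreteness, here is a minimal sketch of the kind of kernel I understand this to describe: a 1024-entry table sampled over an assumed input range, copied into shared memory once per block, with nearest-entry lookup instead of expf. The input range, table size, and lack of interpolation are assumptions for illustration, not the exact code that was benchmarked.

```cuda
// exp_lut.cu -- expf() vs. a 1024-entry shared-memory lookup table (no interpolation).
#include <cuda_runtime.h>
#include <cstdio>
#include <cmath>

#define LUT_SIZE 1024
#define X_MIN   -10.0f   // assumed input range; softmax feeds exp() non-positive values
#define X_MAX     0.0f

__global__ void exp_lut_kernel(const float* __restrict__ x, float* __restrict__ y,
                               const float* __restrict__ lut_global, int n)
{
    __shared__ float lut[LUT_SIZE];

    // Cooperative copy of the table into shared memory (1024 floats = 4 KB per block).
    for (int i = threadIdx.x; i < LUT_SIZE; i += blockDim.x)
        lut[i] = lut_global[i];
    __syncthreads();

    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;

    // Map the input to the nearest table entry; clamping handles out-of-range values.
    float t = (x[idx] - X_MIN) * (LUT_SIZE - 1) / (X_MAX - X_MIN);
    int   k = min(max(__float2int_rn(t), 0), LUT_SIZE - 1);
    y[idx] = lut[k];   // whether this gather bank-conflicts depends on the data
}

__global__ void exp_ref_kernel(const float* __restrict__ x, float* __restrict__ y, int n)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) y[idx] = expf(x[idx]);
}

int main()
{
    const int n = 1 << 20;
    float *x, *y, *lut;
    cudaMallocManaged(&x, n * sizeof(float));
    cudaMallocManaged(&y, n * sizeof(float));
    cudaMallocManaged(&lut, LUT_SIZE * sizeof(float));

    for (int i = 0; i < LUT_SIZE; ++i)   // table: exp() sampled uniformly on [X_MIN, X_MAX]
        lut[i] = std::exp(X_MIN + i * (X_MAX - X_MIN) / (LUT_SIZE - 1));
    for (int i = 0; i < n; ++i)
        x[i] = X_MIN + (X_MAX - X_MIN) * (float)i / (float)n;

    exp_lut_kernel<<<(n + 255) / 256, 256>>>(x, y, lut, n);   // time this ...
    cudaDeviceSynchronize();
    std::printf("lut:  exp(%f) ~ %f\n", x[n / 2], y[n / 2]);

    exp_ref_kernel<<<(n + 255) / 256, 256>>>(x, y, n);        // ... against this
    cudaDeviceSynchronize();
    std::printf("expf: exp(%f) = %f\n", x[n / 2], y[n / 2]);

    cudaFree(x); cudaFree(y); cudaFree(lut);
    return 0;
}
```

Whether the shared-memory gather bank-conflicts (as mentioned above) depends on how the inputs within a warp map onto table indices, which is exactly why the result could swing either way.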
A ratio of just 512x slower than tensor ops is actually pretty good. One clock is 1024 ops in a 32x32 MatMul. A lookup table in ROM would be around 250 µm², small enough to have one per CUDA core but far too large to fit in a MatMul node. So that speed ratio just reflects the beauty of the MatMul, not a flaw in the exponent operator.
3.9 TOps of special functions divided by 17,424 CUDA cores is around 2.2 GOps per core, which is one per clock. Probably pipelined, but being able to launch one special op every clock per tiny core is damn good.
Hi, I saw this recent post about FA-3 by PyTorch and noticed this bit:
Something I had been curious about for a while is how many values computed by exp() in bfloat16 end up out of bounds or collapse onto the same number. Namely, of the 65,536 possible input values, only 2,267 map to outputs that are not 0, 1, or inf, and of course these all sit in predictable segments below or above certain thresholds.
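For anyone who wants to poke at this themselves, here is a rough host-side sketch (plain C++, also fine as a .cu file) that enumerates every 16-bit pattern, treats it as a bfloat16 input, computes exp() in float32, and rounds back to bfloat16 with round-to-nearest-even. The exact count you get depends on details like NaN handling and rounding mode, so the choices below are assumptions rather than the methodology behind the 2,267 figure.

```cpp
// count_bf16_exp.cpp -- enumerate all bfloat16 inputs and bucket the exp() outputs.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <cmath>

// bfloat16 is just the top 16 bits of an IEEE float32.
static float bf16_bits_to_float(uint16_t bits) {
    uint32_t u = (uint32_t)bits << 16;
    float f;
    std::memcpy(&f, &u, sizeof f);
    return f;
}

// float32 -> bfloat16 with round-to-nearest-even (valid for non-NaN inputs).
static uint16_t float_to_bf16_bits(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof u);
    uint32_t bias = 0x7FFFu + ((u >> 16) & 1u);
    return (uint16_t)((u + bias) >> 16);
}

int main() {
    int zeros = 0, ones = 0, infs = 0, nans = 0, other = 0;
    for (uint32_t bits = 0; bits < 0x10000u; ++bits) {
        float x = bf16_bits_to_float((uint16_t)bits);
        if (std::isnan(x)) { ++nans; continue; }        // skip NaN inputs
        float y = bf16_bits_to_float(float_to_bf16_bits(std::exp(x)));
        if (y == 0.0f)          ++zeros;
        else if (y == 1.0f)     ++ones;
        else if (std::isinf(y)) ++infs;
        else                    ++other;                // the "interesting" inputs
    }
    std::printf("exp() -> 0: %d, 1: %d, inf: %d, other: %d (NaN inputs skipped: %d)\n",
                zeros, ones, infs, other, nans);
    return 0;
}
```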
Knowing this, is it possible to speed up the computation by using a simple lookup table? I understand memory is a precious resource, so this may backfire, but I was curious whether this makes any sense at all.
FP8 would let us slim this table down even more (though I know ops are sometimes cast to higher precision, so I'm not really sure).
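To make the lookup-table idea concrete, here is a sketch of the most naive version: a full 2^16-entry table of precomputed bfloat16 results, indexed directly by the input's raw 16 bits. At 65,536 x 2 bytes = 128 KB it is too big for __constant__ memory (64 KB), so this sketch keeps it in global memory and leans on the caches; whether any placement of it can beat the SFU path is exactly the open question. The bit-level bfloat16 helpers and the driver code are illustrative assumptions.

```cuda
// exp_bf16_table.cu -- exp() as a gather from a 2^16-entry table indexed by raw bf16 bits.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <cmath>

__global__ void exp_bf16_table_kernel(const uint16_t* __restrict__ x_bits,
                                      uint16_t* __restrict__ y_bits,
                                      const uint16_t* __restrict__ table, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        y_bits[i] = table[x_bits[i]];   // one gather replaces the exp() evaluation
}

// Host-side bfloat16 emulation: bf16 is the top 16 bits of a float32.
static float bf16_to_float(uint16_t b) {
    uint32_t u = (uint32_t)b << 16; float f; std::memcpy(&f, &u, sizeof f); return f;
}
static uint16_t float_to_bf16(float f) {  // round to nearest even; NaNs not special-cased
    uint32_t u; std::memcpy(&u, &f, sizeof u);
    return (uint16_t)((u + 0x7FFFu + ((u >> 16) & 1u)) >> 16);
}

int main() {
    const int n = 1 << 20;
    uint16_t *x, *y, *table;
    cudaMallocManaged(&x, n * sizeof(uint16_t));
    cudaMallocManaged(&y, n * sizeof(uint16_t));
    cudaMallocManaged(&table, 65536 * sizeof(uint16_t));

    // Precompute exp() once, on the host, for every possible bfloat16 input.
    for (uint32_t b = 0; b < 0x10000u; ++b)
        table[b] = float_to_bf16(std::exp(bf16_to_float((uint16_t)b)));

    for (int i = 0; i < n; ++i)                        // some test inputs in [-10, 0]
        x[i] = float_to_bf16(-10.0f * (float)i / n);

    exp_bf16_table_kernel<<<(n + 255) / 256, 256>>>(x, y, table, n);
    cudaDeviceSynchronize();
    std::printf("exp(%f) ~ %f\n", bf16_to_float(x[n / 2]), bf16_to_float(y[n / 2]));

    cudaFree(x); cudaFree(y); cudaFree(table);
    return 0;
}
```

With FP8 inputs the same trick would need only a 256-entry table, which fits comfortably in shared or even constant memory.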