pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

bitwise_xor seems to be not optimized #132612

Open esason opened 1 month ago

esason commented 1 month ago

🚀 The feature, motivation and pitch

Recently, I analyzed the time complexity of a simple matrix multiplication versus a "matrix xor". (This is a kind of approximation that I am working on, but that is not the point.) The point is the time complexity analysis itself.

The matrix operation is:
(1) B, H, L, dim = 5, 12, 1024, 64
(2) x1 = torch.randn((B, H, L, dim), device='cuda')
(3) x2 = torch.randn((B, H, L, dim), device='cuda')
(4) res1 = torch.matmul(x1, x2.transpose(-2, -1))

With this operation, each entry costs about O(dim) float mul&add operations.

The "matrix xor" operation is: I compared (4) to a case where each entry is a single xor operation. That is:
(5) x1 = torch.randint(2**32, (B, H, L, 1), device='cuda')
(6) x2 = torch.randint(2**32, (B, H, 1, L), device='cuda')
(7) res2 = torch.bitwise_xor(x1, x2)

Each entry requires a single integer xor operation.

Comparing only the run time of (4) vs (7) shows a very small advantage for bitwise_xor, even though this is a single xor versus dim mul&add operations per entry!

I expected it to be much faster, by a factor of roughly dim, but it gives a very thin speedup of 1.5, which is very disappointing. It seems that bitwise_xor within torch is not optimized, and it would be good to optimize it. Thanks!
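One thing worth ruling out when reproducing a comparison like (4) vs (7) is the timing method itself: CUDA kernels are launched asynchronously, so wall-clock timing without a synchronization point can end up measuring launch overhead. The following is a minimal sketch of a synchronized timing loop, not the measurement used in this issue. The helper names `time_op` and `compare` are hypothetical (they are not part of PyTorch), the `randint` upper bound in (5) and (6) is assumed to be `2**32`, and the code falls back to CPU when no GPU is available.

```python
import time

import torch


def time_op(fn, device, warmup=3, iters=10):
    """Return the mean wall-clock seconds per call of fn()."""
    def sync():
        # CUDA kernels run asynchronously; wait until they have finished.
        if device.type == "cuda":
            torch.cuda.synchronize()

    for _ in range(warmup):  # warm-up calls are excluded from the timing
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()
    return (time.perf_counter() - start) / iters


def compare(B=5, H=12, L=1024, dim=64, device=None):
    """Time the matmul (4) and the broadcast xor (7) with the same harness."""
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    a = torch.randn((B, H, L, dim), device=device)
    b = torch.randn((B, H, L, dim), device=device)
    # An upper bound of 2**32 (assumed) requires int64, the default dtype of torch.randint.
    i1 = torch.randint(2**32, (B, H, L, 1), device=device)
    i2 = torch.randint(2**32, (B, H, 1, L), device=device)
    t_mm = time_op(lambda: torch.matmul(a, b.transpose(-2, -1)), device)
    t_xor = time_op(lambda: torch.bitwise_xor(i1, i2), device)
    return t_mm, t_xor
```

Calling `t_mm, t_xor = compare()` and printing `t_mm / t_xor` gives the speedup under this harness. The numbers depend on the GPU and PyTorch version, so none are quoted here.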

Alternatives

No response

Additional context

No response

cc @msaroufim @albanD

albanD commented 1 month ago

I definitely expect mm was studied and optimized a lot more than xor; in particular, there might be vectorization we're missing.

"With this operation, each entry costs about O(dim) float mul&add operations."

This is not fully true though: matmul's complexity is much lower than n^3 in practice. Also keep in mind that, on GPU, most of these ops are purely memory-bandwidth bound rather than compute bound, and both of these ops pull ~the same amount of data.
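The memory-bandwidth argument can be made concrete with a back-of-the-envelope estimate. The sketch below counts the bytes each op has to read and write for the shapes in the issue, assuming the default dtypes (float32 for `torch.randn`, int64 for `torch.randint`), that every input is read once and the output written once, and ignoring caches and any workspace memory. The helper names are hypothetical and the result is an estimate, not a measurement.

```python
from math import prod

FLOAT32_BYTES = 4  # default dtype of torch.randn
INT64_BYTES = 8    # default dtype of torch.randint


def tensor_bytes(shape, itemsize):
    """Size in bytes of a dense tensor with the given shape."""
    return prod(shape) * itemsize


def matmul_traffic(B, H, L, dim):
    """Bytes read plus bytes written by the matmul in (4)."""
    inputs = 2 * tensor_bytes((B, H, L, dim), FLOAT32_BYTES)
    output = tensor_bytes((B, H, L, L), FLOAT32_BYTES)
    return inputs + output


def xor_traffic(B, H, L):
    """Bytes read plus bytes written by the broadcast xor in (7)."""
    inputs = (tensor_bytes((B, H, L, 1), INT64_BYTES)
              + tensor_bytes((B, H, 1, L), INT64_BYTES))
    output = tensor_bytes((B, H, L, L), INT64_BYTES)
    return inputs + output


B, H, L, dim = 5, 12, 1024, 64
mm = matmul_traffic(B, H, L, dim)
xr = xor_traffic(B, H, L)
print(f"matmul: {mm / 1e6:.1f} MB, xor: {xr / 1e6:.1f} MB")
# prints: matmul: 283.1 MB, xor: 504.3 MB
```

Under these assumptions both ops are dominated by writing the (B, H, L, L) output, and the two totals are within a factor of about two of each other. The xor's int64 output is the larger of the two, so memory traffic alone would not predict a speedup of order dim.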

esason commented 1 month ago

Hi @albanD, I made another check. To debug and get some insight into the problem, I focused only on the matmul and compared the cases dim=1, dim=64 and dim=128. The rationale is that the xor in the second case should, at best, behave like the matmul with dim=1.

I got a speedup of only 1.5 when dim=1, which is also surprising in some sense (with dim=128 the speedup is ~2). So the same question stands for the matmul itself: why is it not much faster when dim=1 ....
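One way to look at the dim=1 vs dim=64 vs dim=128 question is to compare how the arithmetic and the memory traffic of the matmul scale with dim. The sketch below is a rough estimate, assuming float32 tensors, inputs read once, the output written once, and no cache effects. The helper name `matmul_cost` is hypothetical.

```python
def matmul_cost(B, H, L, dim, itemsize=4):
    """Return (flops, bytes_moved) for x1 @ x2.transpose(-2, -1) in float32."""
    flops = 2 * B * H * L * L * dim               # one multiply and one add per term
    input_bytes = 2 * B * H * L * dim * itemsize  # x1 and x2, each read once
    output_bytes = B * H * L * L * itemsize       # result written once
    return flops, input_bytes + output_bytes


B, H, L = 5, 12, 1024
base_flops, base_bytes = matmul_cost(B, H, L, 1)
for dim in (1, 64, 128):
    flops, nbytes = matmul_cost(B, H, L, dim)
    print(f"dim={dim:3d}  flops x{flops / base_flops:.0f}  bytes x{nbytes / base_bytes:.2f}")
```

Relative to dim=1, the flop count grows by 64x and 128x, while the estimated bytes moved grow only by about 1.12x and 1.25x, because the (B, H, L, L) output is the same size in all three cases. If the kernel is memory-bound, as suggested earlier in the thread, run time would be expected to follow the bytes column more closely than the flops column, which is at least consistent with seeing only a small speedup at dim=1.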