psmaragdis opened this issue 1 month ago
Assigning @penpornk for the CPU part.
(The CUDA part probably deserves a look as well, but the CPU problem is much worse. It probably means we're falling back to a naive implementation rather than using an optimized kernel.)
Description
Transpose convolutions are orders of magnitude slower than the corresponding regular convolutions and than their torch counterparts (at least for the sizes in the example below). The problem is consistent across both the cpu and cuda backends, so I wouldn't point a finger at CUDA here.
Notebook with timings on Colab is here: https://colab.research.google.com/drive/19g_VmTrK0bScC6p5sqbuND7n0FVi4GqW?usp=sharing
I'm also attaching a .py version of the code at the end; its output on my M1 laptop is:
And on an Ubuntu machine with an RTX4090:
Here is the standalone code. Change the `dev` parameter to either `'cpu'` or `'cuda'` accordingly.

System info (python version, jaxlib version, accelerator, etc.)