
Slow transpose convolutions (both cpu and cuda backends) #23783

Open psmaragdis opened 1 month ago

psmaragdis commented 1 month ago

Description

Transpose convolutions are orders of magnitude slower than the corresponding regular convolutions and than their counterparts in torch (at least for the sizes in the example below). The problem is consistent across both the cpu and cuda backends, so I wouldn't point the finger at CUDA here.

Notebook with timings on Colab is here: https://colab.research.google.com/drive/19g_VmTrK0bScC6p5sqbuND7n0FVi4GqW?usp=sharing

I'm also attaching a .py version of the code at the end; its output on my M1 laptop is:

Using jax 0.4.33 on cpu
Jax 1d conv :    1,632 iterations in 5.09 seconds
Jax 1d convt:       17 iterations in 229.14 seconds
Using torch 2.4.1 on cpu
Torch 1d conv :  1,548 iterations in 5.00 seconds
Torch 1d convt:  2,211 iterations in 5.00 seconds

And on an Ubuntu machine with an RTX4090:

Using jax 0.4.33 on cuda
Jax 1d conv :   73,774 iterations in 5.00 seconds
Jax 1d convt:      388 iterations in 5.22 seconds
Using torch 2.4.1+cu124 on cuda
Torch 1d conv : 121,147 iterations in 5.02 seconds
Torch 1d convt: 168,121 iterations in 5.01 seconds

Here is the standalone code. Change the dev parameter to either 'cpu' or 'cuda' accordingly.

# -*- coding: utf-8 -*-
"""Slow Jax ConvT.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/19g_VmTrK0bScC6p5sqbuND7n0FVi4GqW

# Setup and benchmarking routine
"""

import jax
import torch
torch.backends.cudnn.benchmark = True

sz = 256
hp = sz//4
dev = 'cuda' # can also change to 'cpu', same thing holds

# Block until CUDA is done
def block( y):
    if type( y) == type( jax.numpy.array([])):
        y.block_until_ready()
    elif dev == 'cuda':
        torch.cuda.synchronize()

# Timing routine
def time_it( f):
    from time import time

    # Warmup
    for _ in range( 3):
        y = f()
    block( y)

    # Count how many passes we can queue in 5 sec
    c = 0
    t0 = time()
    while time()-t0 < 5:
        y += f()
        c += 1
    block( y)
    print( f'\t{c:6,d} iterations in {time()-t0:.2f} seconds')

"""# Jax convolutions

Note how the transpose convolution is orders of magnitude slower
"""

from functools import partial

print( 'Using jax', jax.__version__, 'on', dev)

# Jax regular 1d conv
@partial( jax.jit, backend=dev)
def jconvf( x, F):
    return jax.lax.conv_general_dilated( lhs=x, rhs=F,
        window_strides=(hp,), padding=((sz,sz),),
        dimension_numbers=('NCT','OIT','NCT'))

# Jax transpose 1d conv
@partial( jax.jit, backend=dev)
def jconvt( f, F):
    return jax.lax.conv_general_dilated( lhs=f, rhs=F,
        window_strides=(1,), lhs_dilation=(hp,), padding=((sz-1,sz-1),),
        dimension_numbers=('NCT','IOT','NCT'))

# Time them
x = jax.numpy.ones( (16, 1, sz*100))
F = jax.numpy.ones( (sz//2+1, 1, sz))
print( 'Jax 1d conv :', end='')
time_it( lambda: jconvf( x, F))

f = jax.numpy.ones( (16, sz//2+1, sz*100//hp))
F = jax.numpy.ones( (sz//2+1, 1, sz))
print( 'Jax 1d convt:', end='')
time_it( lambda: jconvt( f, F))

"""# Torch convolutions

Both convolution types have comparable runtimes. The regular convolution is on par with Jax; the transpose convolution is far faster than in Jax.
"""

print( 'Using torch', torch.__version__, 'on', dev)

# Torch regular 1d conv
def tconvf( x, F):
    return torch.nn.functional.conv1d( x, F, stride=hp, padding=sz)

# Torch transpose 1d conv
def tconvt( f, F):
    return torch.nn.functional.conv_transpose1d( f, F, stride=hp)

# Time them
x = torch.ones( (16, 1, sz*100), device=dev)
F = torch.ones( (sz//2+1, 1, sz), device=dev)
print( 'Torch 1d conv :', end='')
time_it( lambda: tconvf( x, F))

f = torch.ones( (16, sz//2+1, sz*100//hp), device=dev)
F = torch.ones( (sz//2+1, 1, sz), device=dev)
print( 'Torch 1d convt:', end='')
time_it( lambda: tconvt( f, F))
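
For reference, the same transposed convolution can also be written with jax.lax.conv_transpose, which (as far as I understand) lowers to the same lhs-dilated conv_general_dilated call and therefore likely hits the same slow path. A minimal sketch; the name jconvt_alt and the 'VALID' padding choice are mine and are not verified to match jconvt's explicit padding exactly:

import jax
import jax.numpy as jnp

sz, hp = 256, 64  # same sizes as the benchmark above

# Alternative formulation via jax.lax.conv_transpose; assumed to reduce to
# the same lhs-dilated convolution as jconvt above.
@jax.jit
def jconvt_alt( f, F):  # hypothetical name, not part of the original repro
    return jax.lax.conv_transpose( f, F, strides=(hp,), padding='VALID',
        dimension_numbers=('NCT','IOT','NCT'))

f = jnp.ones( (16, sz//2+1, sz*100//hp))
F = jnp.ones( (sz//2+1, 1, sz))
print( jconvt_alt( f, F).shape)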

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='f4a29f286e8e', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Thu Jun 27 21:05:47 UTC 2024', machine='x86_64')

$ nvidia-smi
Fri Sep 20 00:07:14 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   66C    P0              31W /  70W |  11493MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
hawkinsp commented 1 month ago

Assigning @penpornk for the CPU part.

(The CUDA part probably deserves a look as well, but the CPU problem is much worse. It probably means we're falling back to a naive implementation rather than using an optimized kernel.)
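
For context on the naive-fallback point: lhs_dilation expresses the transposed convolution as an ordinary convolution over an input with stride-1 zeros inserted between samples, so a backend that materializes that mostly-zero input and convolves it densely does roughly stride-times more work than an optimized transposed-conv kernel would. A small sketch of that equivalence on toy sizes (the dilate_lhs helper is mine, purely for illustration):

import jax
import jax.numpy as jnp

# Explicitly insert d-1 zeros between samples along the time axis,
# mimicking what lhs_dilation=(d,) does implicitly.
def dilate_lhs( x, d):
    n, c, t = x.shape
    out = jnp.zeros( (n, c, (t-1)*d + 1), dtype=x.dtype)
    return out.at[:, :, ::d].set( x)

x = jnp.arange( 6.).reshape( 1, 1, 6)
F = jnp.ones( (1, 1, 3))

# Transposed-style conv via implicit input dilation
implicit = jax.lax.conv_general_dilated( x, F, window_strides=(1,),
    padding=((2,2),), lhs_dilation=(2,), dimension_numbers=('NCT','OIT','NCT'))

# Same computation with the zeros materialized up front
explicit = jax.lax.conv_general_dilated( dilate_lhs( x, 2), F, window_strides=(1,),
    padding=((2,2),), dimension_numbers=('NCT','OIT','NCT'))

print( jnp.allclose( implicit, explicit))  # True: identical results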