google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Gradient computations take more time for the first repetitions after compilation than the last #18622

Open MatDag opened 7 months ago

MatDag commented 7 months ago

Description

Hi, for benchmarking purposes I need to measure the time spent computing the gradient of a Flax model w.r.t. the model's parameters. The gradient is jitted, and a first run is performed for compilation. However, when running this code on a GPU, the first five computations take longer than the last fifteen, even though they compute exactly the same thing. Is there an explanation for this?

import jax
import optax
import jax.numpy as jnp
from flax.training import common_utils
from transformers import FlaxResNetForImageClassification, ResNetConfig

from time import perf_counter

def cross_entropy_loss(logits, labels):
    one_hot_labels = common_utils.onehot(labels, num_classes=2)
    xentropy = optax.softmax_cross_entropy(logits=logits,
                                           labels=one_hot_labels)
    return jnp.mean(xentropy)

config_resnet50 = ResNetConfig(
    num_channels=3,
    embedding_size=64,
    hidden_sizes=[256, 512, 1024, 2048],
    depths=[3, 4, 6, 3],
    layer_type='bottleneck',
    hidden_act='relu',
    downsample_in_first_stage=False,
    out_features=None,
    out_indices=None,
)
model = FlaxResNetForImageClassification(config_resnet50)

def loss_fn(params, batch):
    """loss function used for training."""
    logits = model._module.apply(params, batch['images']).logits
    loss = cross_entropy_loss(logits, batch['labels'])
    return loss

if __name__ == '__main__':
    batch_size = 16
    n_reps = 20
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)
    batch = {
        'images': jax.random.normal(key, (batch_size, 224, 224, 3)),
        'labels': jax.random.randint(subkey, (batch_size,), 0, 2)
    }
    params = model.params

    grad_fun = jax.jit(lambda x: jax.grad(loss_fn)(x, batch))

    grad_fun(params)  # First run for compilation
    for _ in range(n_reps):
        start = perf_counter()
        jax.block_until_ready(grad_fun(params))
        print(perf_counter() - start)

Output:

0.027811646927148104
0.027487624902278185
0.0274412389844656
0.027404130902141333
0.027959060855209827
0.02018922194838524
0.018810193985700607
0.01876934664323926
0.018772422801703215
0.01853539375588298
0.018571130000054836
0.01854165457189083
0.01851730002090335
0.018671086058020592
0.018460053950548172
0.018476856406778097
0.01846574479714036
0.01866668788716197
0.01846056431531906
0.018483687192201614

What jax/jaxlib version are you using?

jax v0.4.7, jaxlib v0.4.7+cuda11.cudnn86

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.11, Linux

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.48.07    Driver Version: 515.48.07    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-PCI...  On   | 00000000:C3:00.0 Off |                    0 |
| N/A   33C    P0    35W / 250W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
jakevdp commented 7 months ago

I suspect the issue is that you're not calling block_until_ready on the first invocation of grad_fun(params), and so the subsequent calculations are being asynchronously dispatched while the device is still busy. When I change your compilation run to this:

jax.block_until_ready(grad_fun(params))  # First run for compilation

I see more consistent timing in the first several runs of the benchmark.
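
For reference, a minimal timing harness along these lines might look like the sketch below; the benchmark helper is illustrative (not part of the original script), but it blocks on both the warm-up call and every timed call, so asynchronous dispatch cannot leak into the measurements:

import jax
from time import perf_counter

def benchmark(fn, arg, n_reps=20):
    """Time fn(arg) n_reps times, blocking on every call."""
    jax.block_until_ready(fn(arg))  # warm-up: trigger compilation and wait for the result
    times = []
    for _ in range(n_reps):
        start = perf_counter()
        jax.block_until_ready(fn(arg))  # block so we measure compute time, not dispatch time
        times.append(perf_counter() - start)
    return times

Called as benchmark(grad_fun, params), this times exactly what the original loop times, just with the warm-up run also synchronized.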

MatDag commented 7 months ago

Thank you for answering. I tried your solution and got the same results, unfortunately.

jakevdp commented 6 months ago

I'm unable to reproduce this on a Colab T4 GPU runtime with jax v0.4.20 and the block_until_ready change I suggested above.

Can you try updating your jax and jaxlib versions and see whether that affects the result?
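
After upgrading, something like the following sketch can confirm which versions are actually being picked up; jax.print_environment_info() is available in recent jax releases and also reports the jaxlib version, Python version, and visible devices:

import jax

print("jax", jax.__version__)
# Prints jax/jaxlib/numpy/Python versions plus device information:
jax.print_environment_info()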