jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Performance Discrepancy in `jax.value_and_grad` for Different `argnums` #23085

Open demon2036 opened 3 months ago

demon2036 commented 3 months ago

Description

I'm seeing a significant performance difference when using jax.value_and_grad with different argnums values. Specifically, with argnums=0 the computation is about 3x faster than with argnums=2. Here is a brief description of my setup:

I am performing adversarial training, which requires me to compute gradients with respect to the input image in order to generate adversarial examples. However, this results in a substantial performance slowdown compared to when I'm computing gradients with respect to the model parameters (params). For example, on a TPUv4-8, the code runs at approximately 35.20it/s when using argnums=0 (i.e., computing gradients with respect to params), but drops to around 13.20it/s when using argnums=2 (i.e., computing gradients with respect to image).

Here’s a simplified version of the code:


def block_all(xs):
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
    return xs

state = create_train_state(args)
b = 32
batch = (jnp.ones((b, 224, 224, 3,)), jnp.ones((b,)))
img = batch[0]
img = img.astype(jnp.float32)
label = batch[1].astype(jnp.int32)

@functools.partial(jax.jit)
def test2(image, label, state, ):
    def adversarial_loss2(params, state, image, label):
        logits = state.apply_fn({"params": params}, image)
        loss_value = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, label))
        return loss_value

    grads = jax.grad(adversarial_loss2, argnums=0)(state.params, state, image, label)
    return grads

for step in tqdm.trange(1, args.training_steps + 1, dynamic_ncols=True):
    out = block_all(test2(img, label, state, ))

### System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.31
jaxlib: 0.4.31
numpy:  2.0.1
python: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0]
jax.devices (4 total, 4 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0) TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-f526ddc0-w-0', release='5.19.0-1022-gcp', version='#24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023', machine='x86_64')
jakevdp commented 3 months ago

It's hard to say much definitive without a full reproduction, including what operations state.apply_fn is doing, but in general it's not surprising that taking derivatives with respect to different inputs leads to different performance characteristics. After all, they're different quantities computed via different sequences of operations, and those different sequences will have different performance characteristics.

If you want to get a sense for what's happening in each case, you might look at the HLO that's being generated: see e.g. Ahead of time lowering and compilation.
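
For example, a rough sketch along these lines (with a toy loss standing in for yours) will print the backend-optimized HLO for each argnums:

import jax
import jax.numpy as jnp

def loss(params, image):
    return jnp.sum(jnp.tanh(image @ params))

params = jnp.ones((4, 4))
image = jnp.ones((8, 4))

for argnums in (0, 1):
    # Lower and compile ahead of time, then dump the optimized HLO text.
    lowered = jax.jit(jax.grad(loss, argnums=argnums)).lower(params, image)
    print(f"=== argnums={argnums} ===")
    print(lowered.compile().as_text())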

demon2036 commented 3 months ago

> It's hard to say much definitive without a full reproduction, including what operations state.apply_fn is doing, but in general it's not surprising that taking derivatives with respect to different inputs leads to different performance characteristics. After all, they're different quantities computed via different sequences of operations, and those different sequences will have different performance characteristics.
>
> If you want to get a sense for what's happening in each case, you might look at the HLO that's being generated: see e.g. Ahead of time lowering and compilation.

Hi @jakevdp. Apologies for the unclear code earlier; it was a bit complex because of the way I was passing parameters. I've simplified it and removed the use of state, calling the model directly instead, which should make things more intuitive. In reality, state was just a standard ViT-B/16 without any special modifications. Below is my updated main function, where I no longer use state and adjusted how the gradients are computed: argnums=0 computes gradients with respect to the model parameters, while argnums=1 computes gradients with respect to the input image.

import jax.numpy as jnp
import jax
import optax
import tqdm
from modeling import ViT

def block_all(xs):
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
    return xs

def main():
    rng = jax.random.PRNGKey(0)
    b, h, w, c = 8, 224, 224, 3
    batch = (jnp.ones((b, h, w, c)), jnp.ones((b,)))
    img = batch[0].astype(jnp.float32)
    label = batch[1].astype(jnp.int32)

    model = ViT(layers=12, dim=768, heads=12, labels=1000, layerscale=True)
    params = model.init(rng, img)['params']

    def test2(image, label, params):
        def adversarial_loss2(params, image, label):
            logits = model.apply({"params": params}, image)
            loss_value = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, label))
            return loss_value

        grads = jax.grad(adversarial_loss2, argnums=1)(params, image, label)
        return grads

    test2_jit = jax.jit(test2)

    for step in tqdm.trange(1, 10000000 + 1, dynamic_ncols=True):
        out = block_all(test2_jit(img, label, params))

if __name__ == "__main__":
    main()

And here is the ViT code:


from dataclasses import dataclass, fields
from functools import partial
from typing import Any, Literal
import einops
import flax.linen as nn
import flax.linen.initializers as init
import jax.numpy as jnp
from chex import Array

DenseGeneral = partial(nn.DenseGeneral, kernel_init=init.truncated_normal(0.02))
Dense = partial(nn.Dense, kernel_init=init.truncated_normal(0.02))
Conv = partial(nn.Conv, kernel_init=init.truncated_normal(0.02))

@dataclass
class ViTBase:
    layers: int = 12
    dim: int = 768
    heads: int = 12
    labels: int | None = 1000
    layerscale: bool = False

    patch_size: int = 16
    image_size: int = 224
    posemb: Literal["learnable", "sincos2d"] = "learnable"
    pooling: Literal["cls", "gap"] = "cls"

    dropout: float = 0.0
    droppath: float = 0.0
    grad_ckpt: bool = False

    @property
    def kwargs(self) -> dict[str, Any]:
        return {f.name: getattr(self, f.name) for f in fields(ViTBase)}

    @property
    def head_dim(self) -> int:
        return self.dim // self.heads

    @property
    def hidden_dim(self) -> int:
        return 4 * self.dim

    @property
    def num_patches(self) -> tuple[int, int]:
        return (self.image_size // self.patch_size,) * 2

class PatchEmbed(ViTBase, nn.Module):
    def setup(self):
        self.wte = Conv(
            self.dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
            padding="VALID",
        )
    def __call__(self, x: Array) -> Array:
        x = (self.wte(x)).reshape(x.shape[0], -1, self.dim)
        # x = x.reshape(x.shape[0], -1, self.dim)
        return x

class Attention(ViTBase, nn.Module):
    def setup(self):
        self.wq = DenseGeneral((self.heads, self.head_dim))
        self.wk = DenseGeneral((self.heads, self.head_dim))
        self.wv = DenseGeneral((self.heads, self.head_dim))
        self.wo = DenseGeneral(self.dim, axis=(-2, -1))
        self.drop = nn.Dropout(self.dropout)

    def __call__(self, x: Array, det: bool = True) -> Array:
        z = jnp.einsum("bqhd,bkhd->bhqk", self.wq(x) / self.head_dim ** 0.5, self.wk(x))
        z = jnp.einsum("bhqk,bkhd->bqhd", self.drop(nn.softmax(z), det), self.wv(x))
        return self.drop(self.wo(z), det)

class FeedForward(ViTBase, nn.Module):
    def setup(self):
        self.w1 = Dense(self.hidden_dim)
        self.w2 = Dense(self.dim)
        self.drop = nn.Dropout(self.dropout)

    def __call__(self, x: Array, det: bool = True) -> Array:
        return self.drop(self.w2(self.drop(nn.gelu(self.w1(x)), det)), det)

class ViTLayer(ViTBase, nn.Module):
    def setup(self):
        self.attn = Attention(**self.kwargs)
        self.ff = FeedForward(**self.kwargs)

        self.norm1 = nn.LayerNorm()
        self.norm2 = nn.LayerNorm()
        self.drop = nn.Dropout(self.droppath, broadcast_dims=(1, 2))

        self.scale1 = self.scale2 = 1.0
        if self.layerscale:
            self.scale1 = self.param("scale1", init.constant(1e-4), (self.dim,))
            self.scale2 = self.param("scale2", init.constant(1e-4), (self.dim,))

    def __call__(self, x: Array, det: bool = True) -> Array:
        x = x + self.drop(self.scale1 * self.attn(self.norm1(x), det), det)
        x = x + self.drop(self.scale2 * self.ff(self.norm2(x), det), det)
        return x

class ViT(ViTBase, nn.Module):
    def setup(self):
        self.embed = PatchEmbed(**self.kwargs)
        self.drop = nn.Dropout(self.dropout)

        # The layer class should be wrapped with `nn.remat` if `grad_ckpt` is enabled.
        layer_fn = nn.remat(ViTLayer) if self.grad_ckpt else ViTLayer
        self.layer = [layer_fn(**self.kwargs) for _ in range(self.layers)]

        self.norm = nn.LayerNorm()
        self.head = Dense(self.labels) if self.labels is not None else None

    def __call__(self, x: Array, det: bool = True) -> Array:
        # x = (images - IMAGENET_DEFAULT_MEAN) / IMAGENET_DEFAULT_STD
        # x = self.drop(self.embed(x), det)
        x = self.embed(x)

        for layer in self.layer:
            x = layer(x, det)
        x = self.norm(x)

        # If the classification head is not defined, return the outputs of all
        # tokens instead of pooling to a single vector and computing class logits.
        if self.head is None:
            return x

        if self.pooling == "cls":
            x = x[:, 0, :]
        elif self.pooling == "gap":
            x = x.mean(1)
        return self.head(x)

When running the code above, I measured about 124 it/s for argnums=0 and 15 it/s for argnums=1, indicating that computing gradients with respect to the parameters is significantly faster than computing gradients with respect to the input. However, after the backward pass needed for the parameter gradients, computing the gradient for the input should not take that much more time.

I suspected that the root cause might be in the PatchEmbed class, specifically in the Conv operation.

class PatchEmbed(ViTBase, nn.Module):
    def setup(self):
        self.wte = Conv(
            self.dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
            padding="VALID",
        )
    def __call__(self, x: Array) -> Array:
        x = (self.wte(x)).reshape(x.shape[0], -1, self.dim)
        # x = x.reshape(x.shape[0], -1, self.dim)
        return x

So I tried commenting out the convolution and replacing it with a simple reshape (since ViT-B's reshaped dimensions happen to be compatible with the expected output shape, no errors occurred). The new code looks like this:

class PatchEmbed(ViTBase, nn.Module):
    def setup(self):
        self.wte = Conv(
            self.dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
            padding="VALID",
        )
    def __call__(self, x: Array) -> Array:
        #x = (self.wte(x)).reshape(x.shape[0], -1, self.dim)
        x = x.reshape(x.shape[0], -1, self.dim)
        return x

Surprisingly, this resulted in 128 it/s for argnums=0 and 195 it/s for argnums=1, meaning that computing gradients with respect to the input became much faster than computing gradients with respect to the parameters.

I looked into the flax.linen.Conv implementation, which mostly wraps JAX code without anything special, and I can't quite pinpoint the issue. I'm really puzzled by this behavior and would appreciate any insights you could share. Thank you!
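
For reference, the extra backward-pass work can be made visible with a small sketch that calls jax.lax.conv_general_dilated directly (which is what flax.linen.Conv ultimately wraps); the names and shapes here are just my own stand-ins mimicking PatchEmbed:

import jax
import jax.numpy as jnp

def patch_embed(x, w):
    # Same pattern as the PatchEmbed conv: stride == kernel size == 16, VALID padding.
    return jax.lax.conv_general_dilated(
        x, w, window_strides=(16, 16), padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))

x = jnp.ones((8, 224, 224, 3))
w = jnp.ones((16, 16, 3, 768))

# The two backward passes contain different conv_general_dilated calls: the input
# gradient is itself another convolution with different window/padding settings.
print(jax.make_jaxpr(jax.grad(lambda w: jnp.sum(patch_embed(x, w))))(w))  # d/d(params)
print(jax.make_jaxpr(jax.grad(lambda x: jnp.sum(patch_embed(x, w))))(x))  # d/d(input)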

demon2036 commented 3 months ago

I observed the same issue on the GPU. Since there seems to be no TPUv4 capacity available right now, I can only report results on the GPU. Further testing suggests the problem stems from convolutions with a large kernel size: when I reduced the kernel size to 14, 8, or lower, the extreme slowness disappeared, but when I increased it to 32 or higher it became very slow, which really confuses me.
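
To isolate this, here is a minimal standalone sketch (my own simplification of the repro, again using jax.lax.conv_general_dilated directly) that times only the input gradient of the patchify convolution for different kernel sizes:

import time
import jax
import jax.numpy as jnp

def input_grad_seconds(patch, image_size=224, channels=3, dim=768, batch=8, iters=10):
    # Strided "patchify" conv, same layout as the flax Conv in PatchEmbed:
    # NHWC input, HWIO kernel, stride == kernel size, VALID padding.
    def patchify(x, w):
        return jax.lax.conv_general_dilated(
            x, w, window_strides=(patch, patch), padding="VALID",
            dimension_numbers=("NHWC", "HWIO", "NHWC"))

    x = jnp.ones((batch, image_size, image_size, channels))
    w = jnp.ones((patch, patch, channels, dim))
    grad_x = jax.jit(jax.grad(lambda x: jnp.sum(patchify(x, w))))
    grad_x(x).block_until_ready()  # warm-up / compile
    t0 = time.perf_counter()
    for _ in range(iters):
        grad_x(x).block_until_ready()
    return (time.perf_counter() - t0) / iters

for patch in (8, 14, 16, 32):
    print(f"kernel_size={patch}: {input_grad_seconds(patch) * 1e3:.2f} ms per input gradient")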

When running on the GPU with a kernel size of 16 or larger, the warnings below appear. There are no such warnings on the TPU, and it was these warnings that led me to pinpoint the problem to the convolution layer.

2024-08-17 17:59:38.059112: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.3 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
2024-08-17 17:59:49.517143: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng54{k2=5,k12=-1,k13=1,k14=2,k15=0,k17=384,k18=1,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:02:30.275067: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 2m41.758073871s
Trying algorithm eng54{k2=5,k12=-1,k13=1,k14=2,k15=0,k17=384,k18=1,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:02:31.275216: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng56{k2=8,k12=1,k13=1,k14=3,k15=0,k17=384,k18=1,k22=0,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:04:14.350876: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 1m44.075748815s
Trying algorithm eng56{k2=8,k12=1,k13=1,k14=3,k15=0,k17=384,k18=1,k22=0,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:04:15.351085: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng54{k2=1,k12=11,k13=0,k14=3,k15=0,k17=64,k18=1,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:07:00.626353: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 2m46.275419584s
Trying algorithm eng54{k2=1,k12=11,k13=0,k14=3,k15=0,k17=64,k18=1,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:07:01.626488: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng41{k2=0,k12=15,k13=2,k14=3,k15=0,k17=48,k18=1,k22=0,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:12:24.320609: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 5m23.694191845s
Trying algorithm eng41{k2=0,k12=15,k13=2,k14=3,k15=0,k17=48,k18=1,k22=0,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:12:25.320784: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng55{k2=8,k13=1,k14=3,k18=1,k22=0,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:14:17.146178: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 1m52.825500921s
Trying algorithm eng55{k2=8,k13=1,k14=3,k18=1,k22=0,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:14:17.146233: W external/tsl/tsl/framework/bfc_allocator.cc:291] Allocator (GPU_0_bfc) ran out of memory trying to allocate 3.82GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2024-08-17 18:14:18.146408: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng53{k2=5,k13=1,k14=2,k18=1,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:17:22.133083: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 3m4.986817998s
Trying algorithm eng53{k2=5,k13=1,k14=2,k18=1,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:17:23.133236: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng35{k2=5,k5=2,k14=6} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:21:55.612832: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 4m33.479687455s
Trying algorithm eng35{k2=5,k5=2,k14=6} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:21:56.613141: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng55{k2=3,k13=2,k14=2,k18=1,k22=0,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:27:24.615041: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 5m29.002143139s
Trying algorithm eng55{k2=3,k13=2,k14=2,k18=1,k22=0,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:27:25.615199: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng0{} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:28:13.739699: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 49.124595719s
Trying algorithm eng0{} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
jakevdp commented 3 months ago

I'm having trouble running your code: can you point to what the modeling package is?

These details help, but my first response still holds: differentiating with respect to different arguments results in a different computation, implemented in terms of a different sequence of operations, and it's entirely expected that different sequences of operations will have different performance characteristics.

Also a side note: instead of defining block_all, you could use jax.block_until_ready, which is basically a more robust and performant version of what you wrote.

demon2036 commented 3 months ago

> I'm having trouble running your code: can you point to what the modeling package is?
>
> These details help, but my first response still holds: differentiating with respect to different arguments results in a different computation, implemented in terms of a different sequence of operations, and it's entirely expected that different sequences of operations will have different performance characteristics.
>
> Also a side note: instead of defining block_all, you could use jax.block_until_ready, which is basically a more robust and performant version of what you wrote.

Apologies for the unclear description earlier. The modeling package is just the ViT code I provided above; that code is its entire content. If you create a modeling.py file and copy the ViT code into it, the script should be runnable.

demon2036 commented 3 months ago

> I'm having trouble running your code: can you point to what the modeling package is?
>
> These details help, but my first response still holds: differentiating with respect to different arguments results in a different computation, implemented in terms of a different sequence of operations, and it's entirely expected that different sequences of operations will have different performance characteristics.
>
> Also a side note: instead of defining block_all, you could use jax.block_until_ready, which is basically a more robust and performant version of what you wrote.

Yes, it is expected that differentiating with respect to different argnums results in different computation times. However, if I understand correctly, differentiating with respect to the input should involve the following steps:

  1. Backpropagate through every layer down to the first one (essentially the same work needed for the parameter gradients).
  2. Then compute the derivative with respect to the input, which should be relatively quick.

Given this, the computation time for differentiating with respect to the input should not differ significantly from the time taken for differentiating with respect to the parameters. Therefore, it is surprising to see such a large difference in computation times.
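
As a small sanity check of that reasoning, a toy example (entirely my own, not the ViT above) shows that JAX will return both gradients from a single call via argnums=(0, 1), i.e. from one shared backward pass:

import jax
import jax.numpy as jnp

def toy_loss(params, image):
    return jnp.sum((image * params["w"]) ** 2)

params = {"w": jnp.array(2.0)}
image = jnp.ones((8, 224, 224, 3))

# One value_and_grad call, one backward pass, both gradients returned.
loss, (g_params, g_image) = jax.value_and_grad(toy_loss, argnums=(0, 1))(params, image)
print(jax.tree_util.tree_map(jnp.shape, g_params), g_image.shape)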

Below is the modified code based on your suggestions. I removed the block_all function, replacing it with jax.block_until_ready, and added a modifiable parameter patch_size, which is passed to ViT to control the Conv patch size. I believe this will help you better reproduce the results from my code.

import jax.numpy as jnp
import jax
import optax
import tqdm
from modeling import ViT

def main():
    rng = jax.random.PRNGKey(0)
    b, h, w, c = 8, 224, 224, 3
    patch_size = 16
    batch = (jnp.ones((b, h, w, c)), jnp.ones((b,)))
    img = batch[0]
    img = img.astype(jnp.float32)
    label = batch[1].astype(jnp.int32)

    model = ViT(layers=12, dim=768, heads=12, labels=1000, layerscale=True, patch_size=patch_size, image_size=h)
    params = model.init(rng, img)['params']

    def test2(image, label, params):
        def adversarial_loss2(params, image, label):
            logits = model.apply({"params": params}, image)
            loss_value = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, label))
            return loss_value

        grads = jax.grad(adversarial_loss2, argnums=1)(params, image, label)
        return grads

    test2_jit = jax.jit(test2)

    for step in tqdm.trange(1, 10000000 + 1, dynamic_ncols=True):
        out = jax.block_until_ready(test2_jit(img, label, params))

if __name__ == "__main__":
    main()

You can modify the patch_size parameter. When patch_size=14, argnums=0 runs at 106.98it/s and argnums=1 at 160.28it/s, which I think is expected, with no issues.

However, when patch_size=16, argnums=0 runs at 122.59it/s and argnums=1 at 15.74it/s. I believe this is not expected, because for the ViT model as a whole the main computational cost should be in the attention blocks, in both the forward and backward passes, not in the initial convolution. Regardless of whether the patch size is 16 or 14, the attention blocks' computation is significantly more costly. So the performance shouldn't degrade this much when changing patch_size=14 to patch_size=16; in fact, modifying patch_size shouldn't affect the speed of the gradient calculation by such a large margin.

You mentioned that different arguments result in a different computation, implemented in terms of a different sequence of operations, and that it's entirely expected that different sequences of operations will have different performance characteristics. I'm not sure if my understanding is correct, so please correct me if I'm wrong: if we are differentiating with respect to argnums=1 (the input), then changing only the patch_size shouldn't result in a different sequence of operations, right? After all, the same operations are being performed, just with a different kernel size. That's why such a large difference in performance seems abnormal to me.