Open demon2036 opened 3 months ago
It's hard to say much definitive without a full reproduction, including what operations state.apply_fun is doing, but in general it's not surprising that taking derivatives with respect to different inputs leads to different performance characteristics. After all, they're different quantities computed via different sequences of operations, and those different sequences will have different performance characteristics.
If you want to get a sense for what's happening in each case, you might look at the HLO that's being generated: see e.g. Ahead-of-time lowering and compilation.
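For instance, a minimal sketch of that kind of inspection (the function and inputs here are placeholders, not the code from this issue):

import jax
import jax.numpy as jnp

def loss(params, x):
    return jnp.sum(params * x)  # stand-in for the real model/loss

params = jnp.ones((4,))
x = jnp.arange(4.0)

# Differentiating w.r.t. params vs. w.r.t. the input yields different programs.
g_params = jax.jit(jax.grad(loss, argnums=0))
g_input = jax.jit(jax.grad(loss, argnums=1))

# Ahead-of-time lowering/compilation exposes the HLO that XLA will actually run.
print(g_params.lower(params, x).compile().as_text())
print(g_input.lower(params, x).compile().as_text())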
Hi @jakevdp. Apologies for the unclear code earlier; it was a bit complex due to the way I handle parameter passing. I've taken some time to simplify it and removed the use of state, now calling the model directly. This should make things more intuitive. In reality, the state was just a standard ViT-B/16 without any special modifications. Below is my updated main function code, where I no longer use state and made some adjustments when computing the gradients. argnums=0 calculates gradients with respect to the model parameters, while argnums=1 computes gradients with respect to the input image.
import jax.numpy as jnp
import jax
import optax
import tqdm
from modeling import ViT


def block_all(xs):
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
    return xs


def main():
    rng = jax.random.PRNGKey(0)
    b, h, w, c = 8, 224, 224, 3
    batch = (jnp.ones((b, h, w, 3)), jnp.ones((b,)))
    img = batch[0].astype(jnp.float32)
    label = batch[1].astype(jnp.int32)

    model = ViT(layers=12, dim=768, heads=12, labels=1000, layerscale=True)
    params = model.init(rng, img)['params']

    def test2(image, label, params):
        def adversarial_loss2(params, image, label):
            logits = model.apply({"params": params}, image)
            loss_value = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, label))
            return loss_value

        grads = jax.grad(adversarial_loss2, argnums=1)(params, image, label)
        return grads

    test2_jit = jax.jit(test2)

    for step in tqdm.trange(1, 10000000 + 1, dynamic_ncols=True):
        out = block_all(test2_jit(img, label, params))


if __name__ == "__main__":
    main()
And here is the ViT code:
from dataclasses import dataclass, fields
from functools import partial
from typing import Any, Literal

import einops
import flax.linen as nn
import flax.linen.initializers as init
import jax.numpy as jnp
from chex import Array

DenseGeneral = partial(nn.DenseGeneral, kernel_init=init.truncated_normal(0.02))
Dense = partial(nn.Dense, kernel_init=init.truncated_normal(0.02))
Conv = partial(nn.Conv, kernel_init=init.truncated_normal(0.02))


@dataclass
class ViTBase:
    layers: int = 12
    dim: int = 768
    heads: int = 12
    labels: int | None = 1000
    layerscale: bool = False
    patch_size: int = 16
    image_size: int = 224
    posemb: Literal["learnable", "sincos2d"] = "learnable"
    pooling: Literal["cls", "gap"] = "cls"
    dropout: float = 0.0
    droppath: float = 0.0
    grad_ckpt: bool = False

    @property
    def kwargs(self) -> dict[str, Any]:
        return {f.name: getattr(self, f.name) for f in fields(ViTBase)}

    @property
    def head_dim(self) -> int:
        return self.dim // self.heads

    @property
    def hidden_dim(self) -> int:
        return 4 * self.dim

    @property
    def num_patches(self) -> tuple[int, int]:
        return (self.image_size // self.patch_size,) * 2


class PatchEmbed(ViTBase, nn.Module):
    def setup(self):
        self.wte = Conv(
            self.dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
            padding="VALID",
        )

    def __call__(self, x: Array) -> Array:
        x = (self.wte(x)).reshape(x.shape[0], -1, self.dim)
        # x = x.reshape(x.shape[0], -1, self.dim)
        return x


class Attention(ViTBase, nn.Module):
    def setup(self):
        self.wq = DenseGeneral((self.heads, self.head_dim))
        self.wk = DenseGeneral((self.heads, self.head_dim))
        self.wv = DenseGeneral((self.heads, self.head_dim))
        self.wo = DenseGeneral(self.dim, axis=(-2, -1))
        self.drop = nn.Dropout(self.dropout)

    def __call__(self, x: Array, det: bool = True) -> Array:
        z = jnp.einsum("bqhd,bkhd->bhqk", self.wq(x) / self.head_dim ** 0.5, self.wk(x))
        z = jnp.einsum("bhqk,bkhd->bqhd", self.drop(nn.softmax(z), det), self.wv(x))
        return self.drop(self.wo(z), det)


class FeedForward(ViTBase, nn.Module):
    def setup(self):
        self.w1 = Dense(self.hidden_dim)
        self.w2 = Dense(self.dim)
        self.drop = nn.Dropout(self.dropout)

    def __call__(self, x: Array, det: bool = True) -> Array:
        return self.drop(self.w2(self.drop(nn.gelu(self.w1(x)), det)), det)


class ViTLayer(ViTBase, nn.Module):
    def setup(self):
        self.attn = Attention(**self.kwargs)
        self.ff = FeedForward(**self.kwargs)

        self.norm1 = nn.LayerNorm()
        self.norm2 = nn.LayerNorm()
        self.drop = nn.Dropout(self.droppath, broadcast_dims=(1, 2))

        self.scale1 = self.scale2 = 1.0
        if self.layerscale:
            self.scale1 = self.param("scale1", init.constant(1e-4), (self.dim,))
            self.scale2 = self.param("scale2", init.constant(1e-4), (self.dim,))

    def __call__(self, x: Array, det: bool = True) -> Array:
        x = x + self.drop(self.scale1 * self.attn(self.norm1(x), det), det)
        x = x + self.drop(self.scale2 * self.ff(self.norm2(x), det), det)
        return x


class ViT(ViTBase, nn.Module):
    def setup(self):
        self.embed = PatchEmbed(**self.kwargs)
        self.drop = nn.Dropout(self.dropout)

        # The layer class should be wrapped with `nn.remat` if `grad_ckpt` is enabled.
        layer_fn = nn.remat(ViTLayer) if self.grad_ckpt else ViTLayer
        self.layer = [layer_fn(**self.kwargs) for _ in range(self.layers)]

        self.norm = nn.LayerNorm()
        self.head = Dense(self.labels) if self.labels is not None else None

    def __call__(self, x: Array, det: bool = True) -> Array:
        # x = (images - IMAGENET_DEFAULT_MEAN) / IMAGENET_DEFAULT_STD
        # x = self.drop(self.embed(x), det)
        x = self.embed(x)
        for layer in self.layer:
            x = layer(x, det)

        x = self.norm(x)

        # If the classification head is not defined, then return the output of all
        # tokens instead of pooling to a single vector and then calculating class logits.
        if self.head is None:
            return x

        if self.pooling == "cls":
            x = x[:, 0, :]
        elif self.pooling == "gap":
            x = x.mean(1)
        return self.head(x)
When running the code above, I saw about 124 it/s for argnums=0 and 15 it/s for argnums=1, indicating that computing gradients with respect to the parameters is significantly faster than computing gradients with respect to the input. However, after computing gradients with respect to the parameters, calculating the gradients for the input should not take that much more time.
I suspected that the root cause might be in the PatchEmbed class, specifically in the Conv operation.
class PatchEmbed(ViTBase, nn.Module):
    def setup(self):
        self.wte = Conv(
            self.dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
            padding="VALID",
        )

    def __call__(self, x: Array) -> Array:
        x = (self.wte(x)).reshape(x.shape[0], -1, self.dim)
        # x = x.reshape(x.shape[0], -1, self.dim)
        return x
So, I tried commenting out the convolution and replacing it with a simple reshape (for ViT-B the reshaped dimensions happen to match the convolution's output dimensions, so no shape errors occur). The new code looks like this:
class PatchEmbed(ViTBase, nn.Module):
    def setup(self):
        self.wte = Conv(
            self.dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
            padding="VALID",
        )

    def __call__(self, x: Array) -> Array:
        # x = (self.wte(x)).reshape(x.shape[0], -1, self.dim)
        x = x.reshape(x.shape[0], -1, self.dim)
        return x
Surprisingly, this resulted in 128 it/s for argnums=0 and 195 it/s for argnums=1, meaning that computing gradients with respect to the input became much faster than computing gradients with respect to the parameters.
I looked into the flax.linen.Conv implementation, which mostly wraps JAX code without anything special, and I can't quite pinpoint the issue. I'm really puzzled by this behavior and would appreciate any insights you could share. Thank you!
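One way to isolate the suspected convolution is to benchmark only the patch-embedding conv's gradient for each argnums. A minimal sketch, assuming the same ViT-B/16 shapes as above (illustrative, not code that was run in this thread):

import time
import jax
import jax.numpy as jnp
import flax.linen as nn

# Same shapes as the patch embedding above: 16x16, stride-16 conv to 768 channels.
conv = nn.Conv(features=768, kernel_size=(16, 16), strides=(16, 16), padding="VALID")
x = jnp.ones((8, 224, 224, 3), dtype=jnp.float32)
variables = conv.init(jax.random.PRNGKey(0), x)

def loss(variables, x):
    return jnp.sum(conv.apply(variables, x))

grad_fns = {
    "wrt params (argnums=0)": jax.jit(jax.grad(loss, argnums=0)),
    "wrt input (argnums=1)": jax.jit(jax.grad(loss, argnums=1)),
}

for name, fn in grad_fns.items():
    jax.block_until_ready(fn(variables, x))  # compile and warm up
    t0 = time.perf_counter()
    for _ in range(20):
        jax.block_until_ready(fn(variables, x))
    print(name, (time.perf_counter() - t0) / 20, "seconds per call")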
I observed the same issue on the GPU. Since there seems to be no TPUv4 capacity available right now, I can only report the situation on the GPU. I did further testing and found that the problem seems to stem from convolutions with a large kernel size. When I reduced the kernel size to 14, 8, or lower, the extreme slowness disappeared. Conversely, when I increased the kernel size to 32 or higher, it became very slow, which really confuses me.
When running on the GPU with a kernel size greater than or equal to 16, the following warnings appear. There are no such warnings on the TPU, and it was this output that led me to pinpoint the problem to the convolution layer.
2024-08-17 17:59:38.059112: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.3 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
2024-08-17 17:59:49.517143: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng54{k2=5,k12=-1,k13=1,k14=2,k15=0,k17=384,k18=1,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:02:30.275067: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 2m41.758073871s
Trying algorithm eng54{k2=5,k12=-1,k13=1,k14=2,k15=0,k17=384,k18=1,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:02:31.275216: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng56{k2=8,k12=1,k13=1,k14=3,k15=0,k17=384,k18=1,k22=0,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:04:14.350876: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 1m44.075748815s
Trying algorithm eng56{k2=8,k12=1,k13=1,k14=3,k15=0,k17=384,k18=1,k22=0,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:04:15.351085: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng54{k2=1,k12=11,k13=0,k14=3,k15=0,k17=64,k18=1,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:07:00.626353: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 2m46.275419584s
Trying algorithm eng54{k2=1,k12=11,k13=0,k14=3,k15=0,k17=64,k18=1,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:07:01.626488: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng41{k2=0,k12=15,k13=2,k14=3,k15=0,k17=48,k18=1,k22=0,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:12:24.320609: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 5m23.694191845s
Trying algorithm eng41{k2=0,k12=15,k13=2,k14=3,k15=0,k17=48,k18=1,k22=0,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:12:25.320784: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng55{k2=8,k13=1,k14=3,k18=1,k22=0,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:14:17.146178: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 1m52.825500921s
Trying algorithm eng55{k2=8,k13=1,k14=3,k18=1,k22=0,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:14:17.146233: W external/tsl/tsl/framework/bfc_allocator.cc:291] Allocator (GPU_0_bfc) ran out of memory trying to allocate 3.82GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2024-08-17 18:14:18.146408: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng53{k2=5,k13=1,k14=2,k18=1,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:17:22.133083: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 3m4.986817998s
Trying algorithm eng53{k2=5,k13=1,k14=2,k18=1,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:17:23.133236: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng35{k2=5,k5=2,k14=6} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:21:55.612832: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 4m33.479687455s
Trying algorithm eng35{k2=5,k5=2,k14=6} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:21:56.613141: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng55{k2=3,k13=2,k14=2,k18=1,k22=0,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:27:24.615041: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 5m29.002143139s
Trying algorithm eng55{k2=3,k13=2,k14=2,k18=1,k22=0,k23=0} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:27:25.615199: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng0{} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2024-08-17 18:28:13.739699: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 49.124595719s
Trying algorithm eng0{} for conv (f32[3,1,224,224]{3,2,1,0}, u8[0]{0}) custom-call(f32[3,768,16,16]{3,2,1,0}, f32[1,768,209,209]{3,2,1,0}), window={size=209x209 pad=208_208x208_208 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
I'm having trouble running your code: can you point to what the modeling package is?
These details help, but my first response still holds: differentiating with respect to different arguments results in a different computation, implemented in terms of a different sequence of operations, and it's entirely expected that different sequences of operations will have different performance characteristics.
Also a side note: instead of defining block_all, you could use jax.block_until_ready, which is basically a more robust and performant version of what you wrote.
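For example (a tiny illustrative sketch; the lambda is just a stand-in for the jitted function above):

import jax
import jax.numpy as jnp

f = jax.jit(lambda x: (x + 1, x * 2))  # stand-in for test2_jit
# Blocks on every array leaf in the output pytree, then returns it unchanged.
out = jax.block_until_ready(f(jnp.ones(3)))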
Apologies for the unclear description earlier. The modeling package refers to the ViT code I provided above; that code is the entire content of the modeling package. You can create a modeling.py file and copy the ViT code from above into it. It should be runnable this way.
Yes, it is expected that differentiating with respect to different argnums results in different computation times. However, if I understand correctly, differentiating with respect to the input should involve essentially the same backward pass as differentiating with respect to the parameters, plus one extra step that propagates the gradient back through the patch-embedding convolution to the input (this should be relatively quick). Given this, the computation time for differentiating with respect to the input should not differ significantly from the time taken for differentiating with respect to the parameters. Therefore, it is surprising to see such a large difference in computation times.
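One simple way to probe this intuition (an illustrative sketch, not something run in this thread) is to request both gradients in a single call; if the input gradient really were only a small amount of extra work, the combined call should cost roughly the same as either gradient alone:

import jax
import jax.numpy as jnp

def loss(params, x):
    return jnp.sum((params * x) ** 2)  # stand-in for the real loss

params = jnp.ones((4,))
x = jnp.arange(4.0)

# argnums=(0, 1) produces both gradients from one traced backward pass.
dparams, dx = jax.jit(jax.grad(loss, argnums=(0, 1)))(params, x)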
Below is the modified code based on your suggestions. I removed the block_all function, replacing it with jax.block_until_ready, and added a modifiable parameter patch_size, which is passed to ViT to control the Conv patch size. I believe this will help you reproduce my results.
import jax.numpy as jnp
import jax
import optax
import tqdm
from modeling import ViT


def main():
    rng = jax.random.PRNGKey(0)
    b, h, w, c = 8, 224, 224, 3
    patch_size = 16
    batch = (jnp.ones((b, h, w, 3)), jnp.ones((b,)))
    img = batch[0].astype(jnp.float32)
    label = batch[1].astype(jnp.int32)

    model = ViT(layers=12, dim=768, heads=12, labels=1000, layerscale=True,
                patch_size=patch_size, image_size=h)
    params = model.init(rng, img)['params']

    def test2(image, label, params):
        def adversarial_loss2(params, image, label):
            logits = model.apply({"params": params}, image)
            loss_value = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, label))
            return loss_value

        grads = jax.grad(adversarial_loss2, argnums=1)(params, image, label)
        return grads

    test2_jit = jax.jit(test2)

    for step in tqdm.trange(1, 10000000 + 1, dynamic_ncols=True):
        out = jax.block_until_ready(test2_jit(img, label, params))


if __name__ == "__main__":
    main()
You can modify the patch_size parameter. When patch_size=14, argnums=0 runs at 106.98 it/s and argnums=1 at 160.28 it/s, which I think is expected, with no issues.
However, when patch_size=16, argnums=0 runs at 122.59 it/s and argnums=1 at 15.74 it/s. I believe this is not as expected, because for the ViT model as a whole most of the computation should be in the attention blocks, whether in the forward or backward pass, not in the initial convolution. Regardless of whether the patch size is 16 or 14, the attention blocks' computation is significantly more costly, so performance shouldn't degrade this much when changing patch_size=14 to patch_size=16. In fact, modifying patch_size alone shouldn't change the speed of the gradient computation by such a large margin.
You mentioned that different arguments result in a different computation, implemented in terms of a different sequence of operations, and that it's entirely expected that different sequences of operations will have different performance characteristics. I'm not sure if my understanding is correct, so please correct me if I'm wrong. As I understand it, if we are differentiating with respect to argnums=1, i.e. the input, then changing only the patch_size shouldn't result in a different sequence of operations, right? After all, the same operations are being performed, just with a different kernel size. Therefore, I believe that such a significant difference in performance is abnormal.
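One way to check whether the traced program really differs between the two patch sizes (again just a sketch on the conv alone; this was not run in the thread) is to compare the jaxprs, or likewise the lowered/compiled HLO:

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((8, 224, 224, 3), dtype=jnp.float32)

for patch in (14, 16):
    conv = nn.Conv(features=768, kernel_size=(patch, patch), strides=(patch, patch), padding="VALID")
    variables = conv.init(jax.random.PRNGKey(0), x)

    def loss(v, inp):
        return jnp.sum(conv.apply(v, inp))

    # The primitives are the same for both patch sizes; only the convolution
    # shapes and window parameters differ.
    print(f"patch_size={patch}")
    print(jax.make_jaxpr(jax.grad(loss, argnums=1))(variables, x))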
Description
I'm experiencing a significant performance difference when using jax.value_and_grad with different argnums values. Specifically, when setting argnums=0, the computation is about 3x faster compared to using argnums=2. Here is a brief description of my setup:
- I have a model with parameters (params) and an input image (image).
- When I use argnums=0 to compute the gradient with respect to params, the execution is significantly faster than when I set argnums=2 to compute the gradient with respect to image.
I am performing adversarial training, which requires me to compute gradients with respect to the input image in order to generate adversarial examples. However, this results in a substantial performance slowdown compared to computing gradients with respect to the model parameters (params). For example, on a TPUv4-8, the code runs at approximately 35.20 it/s when using argnums=0 (i.e., computing gradients with respect to params), but drops to around 13.20 it/s when using argnums=2 (i.e., computing gradients with respect to image).
Here's a simplified version of the code: