openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

XLA does too many un-fused transposes #16914

Open ywrt opened 2 months ago

ywrt commented 2 months ago

(This is running on an NVIDIA 4090 GPU, with jax 0.4.31.)

I have code that is something like the example below. Here, the depth-wise convolution wants the input to be transposed from [batch, sequence, feature] into [batch, feature, sequence] so that it can apply the convolution along the sequence axis.

The output from the convolution is used 3 times, and XLA generates at least 3 separate (fused) transposes, each of which does a full read and write of memory. This is very slow and causes sadness.

Unfortunately, this example code doesn't reproduce the problem: the problem seems to be quite sensitive to the surrounding code, and trying to trim it down makes most of the issue go away. A screen-grab from the profiled code somewhat shows the issue:

[Screenshot from 2024-09-09 11-51-25]

After the convolution there is a loop-transpose_fusion; then, after the two cutlass gemm kernels, there are two input_transpose_fusion kernels; and following kernel__1 is another input_transpose_fusion. Each of these fusions does a full read/write of memory.

My main question is: How can I effectively debug this? E.g. is there a way to log all the GPU kernel calls along with the argument shapes? Is there some way to see why the transposes didn't fuse into a single kernel with 1 input and 3 outputs?

import jax
import jax.numpy as jnp
import flax.linen as nn

class Griffin(nn.Module):
  hidden: int = 512
  window: int = 16
  @nn.compact
  def __call__(self, x):
    left = nn.Dense(self.hidden)(x)
    left = jax.nn.silu(left)

    right = nn.Dense(self.hidden)(x)
    right = nn.Conv(right.shape[-1],                
                kernel_size=(self.window,),
                padding='CAUSAL',
                feature_group_count=right.shape[-1],
            )(right)

    # Input gate
    gate = nn.Dense(right.shape[-1])(right)
    gate = jax.nn.sigmoid(gate)
    gated = right * gate
    # Generate decay rate.
    decay = nn.Dense(right.shape[-1])(right)
    decay = jnp.exp(-8 * jax.nn.sigmoid(decay))

    right = kernel_recur(gated, decay) # Apply linear recurrence along axis=1

    o = left * right
    x += nn.Dense(x.shape[-1])(o)
    return x.mean()

net = Griffin()
params = net.init(jax.random.key(0), jnp.zeros((64, 1024, 256)))['params']

o = jax.jit(jax.value_and_grad(net.apply))({'params': params}, jnp.zeros((64, 1024, 256)))

module_0071.jit_apply.sm_8.9_gpu_after_optimizations.txt
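One lighter-weight way to look at this (not from the thread, just a sketch using the public `jax.jit` lowering APIs; the toy function below stands in for the real model) is to print the after-optimizations HLO for a jitted function directly, without environment flags:

```python
# Hedged sketch: inspect the post-optimization HLO for a jitted function.
# lower()/compile()/as_text() are public jax.jit APIs (jax 0.4.x).
import jax
import jax.numpy as jnp

def f(x):
    # Toy stand-in for the real model: a transpose feeding a matmul.
    xt = jnp.transpose(x, (0, 2, 1))
    return (x @ xt).sum()

lowered = jax.jit(f).lower(jnp.zeros((2, 3, 4)))
hlo = lowered.compile().as_text()  # after-optimizations HLO as one string
print(hlo.splitlines()[0])         # first line names the HloModule
```

Searching this text for transpose/fusion instructions gives the same information as the dumped after-optimizations file, scoped to one function.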

cheshire commented 1 month ago

eg. Is there a way to log all the GPU kernel calls along with the argument shapes?

Yes, you could use logging in the *thunk files, but I'm not sure it will help you: at that point, as you've pointed out, the fusion decisions have already been made.

My main question is: How can I effectively debug this?

You could start by dumping HLO after every pass with --xla_dump_hlo_pass_re=.* and figuring out what pass makes bad decisions, or bisecting that somehow.
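As a concrete sketch of that workflow (the flags are real XLA flags; the toy function and dump path are placeholders), note that XLA_FLAGS must be set before jax initializes its backend:

```python
# Hedged sketch: dump HLO after every pass via XLA_FLAGS.
# The flags must be set before importing jax, or the backend won't see them.
import os
os.environ["XLA_FLAGS"] = (
    "--xla_dump_to=/tmp/xla_dump "
    "--xla_dump_hlo_pass_re=.*"
)
import jax
import jax.numpy as jnp

jax.jit(lambda x: (x @ x.T).sum())(jnp.ones((8, 8)))
# One file per pass per module; diff neighboring files to find
# the pass that introduced the extra transposes.
print(len(os.listdir("/tmp/xla_dump")))
```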

ywrt commented 1 month ago

Yes, you could use logging in *thunks files, but I'm not sure it will help you, as at that point as you've pointed out fusion decisions have been done.

Would you know how I can do that? In the first instance, I'm looking to be able to clearly see that the same memory is being read multiple times by different kernel calls. Is there some flag for the runtime that will have it log (e.g.) stream executions along with argument shapes?

You could start by dumping HLO after every pass with --xla_dump_hlo_pass_re=.* and figuring out what pass makes bad decisions, or bisecting that somehow.

Yes, I've been looking over this, but it's a lot of manual effort matching up the passes due to renaming and the like (oh, and my limited familiarity!).

Is there any documentation (beyond the source code) for things like the gpu_after_optimizations-buffer-assignment.txt file?

I guess I'm trying to see the memory read/write graph at the CUDA kernel boundaries. Is there some existing way to see this? If not, I guess I'll try to put together a script to generate a dot graph.
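(One possible shortcut, not mentioned in the thread: XLA can already emit GraphViz dot for dumped HLO modules via the real flag --xla_dump_hlo_as_dot; the toy function below is a placeholder. That may save writing a custom script.)

```python
# Hedged sketch: ask XLA itself to emit .dot graphs for dumped HLO modules.
# Flags must be set before importing jax.
import os
os.environ["XLA_FLAGS"] = (
    "--xla_dump_to=/tmp/xla_dot "
    "--xla_dump_hlo_as_dot=true"
)
import jax
import jax.numpy as jnp

jax.jit(lambda x: (x @ x.T).sum())(jnp.ones((8, 8)))
dots = [f for f in os.listdir("/tmp/xla_dot") if f.endswith(".dot")]
print(dots)  # render with: dot -Tsvg FILE.dot > FILE.svg
```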

cheshire commented 1 month ago

Would you know how I can do that

You could look at the various *thunk files: inside their Run methods they all know the buffers they are acting on, so you could either reuse the existing VLOG lines or add your own.

guess I'm trying to see the memory read/write graph at the CUDA kernel boundaries. Is there some existing way to see this?

Yes, that would be buffer assignment. You'll only see offsets there (you can match them with instruction names in the after-optimizations HLO), as the actual memory is allocated at runtime.

nullhook commented 1 month ago

fwiw, you can see how to parse the buffer assignment format here: https://github.com/openxla/xla/blob/main/xla/tools/driver.cc#L311