ywrt opened 2 months ago
> e.g. Is there a way to log all the GPU kernel calls along with the argument shapes?
Yes, you could use logging in the `*thunk` files, but I'm not sure it will help you, because at that point, as you've pointed out, the fusion decisions have already been made.
> My main question is: how can I effectively debug this?
You could start by dumping the HLO after every pass with `--xla_dump_hlo_pass_re=.*` and figuring out which pass makes bad decisions, or bisecting that somehow.
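Here is a minimal sketch of that suggestion from JAX. The per-pass dumps also need `--xla_dump_to` so they have somewhere to go, and `/tmp/xla_dump` is an arbitrary example directory. The flags go through the `XLA_FLAGS` environment variable, which has to be set before JAX initialises its backend:

```python
import os

# Append rather than overwrite, in case XLA_FLAGS already holds other flags.
# /tmp/xla_dump is an arbitrary example directory.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_dump_to=/tmp/xla_dump --xla_dump_hlo_pass_re=.*"
).strip()

import jax  # imported after setting XLA_FLAGS so the flags are picked up
import jax.numpy as jnp


@jax.jit
def f(x):
    return jnp.transpose(x) * 2.0


f(jnp.ones((4, 8))).block_until_ready()
# /tmp/xla_dump now contains HLO text dumps for the compiled module,
# including one file per pass that matched the regex.
```

On a GPU the directory fills up quickly, so a narrower regex (e.g. matching only fusion-related pass names) may be easier to work with.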
> Yes, you could use logging in the `*thunk` files, but I'm not sure it will help you, because at that point, as you've pointed out, the fusion decisions have already been made.
Would you know how I can do that? In the first instance, I'd like to be able to see clearly that the same memory is being read multiple times by different kernel calls. Is there some runtime flag that will make it (e.g.) log stream executions along with argument shapes?
> You could start by dumping the HLO after every pass with `--xla_dump_hlo_pass_re=.*` and figuring out which pass makes bad decisions, or bisecting that somehow.
Yes, I've been looking over this, but it's a lot of manual effort to match up the passes because of renaming and the like (oh, and my limited familiarity!).
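One way to cut down the manual matching is to count a pattern, say `transpose(`, in each per-pass dump and print only the passes where the count changes. This sketch assumes the per-pass filenames carry a zero-padded step counter, so that sorting them by name follows pass order; check that against your own dump directory. The glob also picks up non-per-pass files such as `*after_optimizations.txt`, so it may need narrowing. The module prefix in the usage line is taken from the attached file and is only an example.

```python
import glob
import os
import re


def pattern_counts(dump_dir, file_glob, pattern):
    """Count regex matches in each file matching file_glob, sorted by filename.

    Assumes lexicographic filename order follows pass order (zero-padded counter).
    """
    regex = re.compile(pattern)
    counts = []
    for path in sorted(glob.glob(os.path.join(dump_dir, file_glob))):
        with open(path) as f:
            counts.append((os.path.basename(path), len(regex.findall(f.read()))))
    return counts


def changes(counts):
    """Return (file, previous count, new count) wherever the count changes."""
    return [
        (name, prev, cur)
        for (_, prev), (name, cur) in zip(counts, counts[1:])
        if cur != prev
    ]


if __name__ == "__main__":
    # Example prefix only; substitute the module name from your own dump.
    counts = pattern_counts("/tmp/xla_dump", "module_0071.jit_apply.*.txt", r"\btranspose\(")
    for name, before, after in changes(counts):
        print(f"{name}: {before} -> {after} transposes")
```

The regex `\btranspose\(` matches transpose instructions but not fusion names like `input_transpose_fusion`. Swap in `fusion\(` to watch the number of fusions instead.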
Is there any documentation (beyond the source code) for things like the `gpu_after_optimizations-buffer-assignment.txt` file?
I guess I'm trying to see the memory read/write graph at the CUDA kernel boundaries. Is there some existing way to see this? If not, I guess I'll try to put together a script to generate a dot graph.
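The dot-graph idea can be sketched as follows. Given a mapping from each top-level HLO instruction to its opcode and the names it reads, emit a Graphviz digraph in which each edge points from a buffer's producer to an op that reads it. The mapping is assumed to have been extracted beforehand from the ENTRY computation of the after-optimizations HLO, for example with a small regex-based parser. The example dict and the `non_kernel_ops` set are made up for illustration.

```python
def to_dot(instrs, non_kernel_ops=frozenset(
        {"parameter", "constant", "get-tuple-element", "tuple", "bitcast"})):
    """Render {name: (opcode, [operand names])} as a Graphviz digraph.

    Ops assumed to launch kernels are drawn as boxes, the rest as ellipses.
    Each edge goes from the producer of a buffer to an op that reads it.
    """
    lines = ["digraph kernels {"]
    for name, (opcode, operands) in instrs.items():
        shape = "ellipse" if opcode in non_kernel_ops else "box"
        lines.append(f'  "{name}" [label="{name}\\n{opcode}", shape={shape}];')
        for operand in operands:
            lines.append(f'  "{operand}" -> "{name}";')
    lines.append("}")
    return "\n".join(lines)


# Hypothetical example: one buffer ("y") read by two separate transpose kernels.
example = {
    "a": ("parameter", []),
    "y": ("fusion", ["a"]),
    "t1": ("fusion", ["y"]),
    "t2": ("fusion", ["y"]),
}
dot_src = to_dot(example)
print(dot_src)
```

Writing `dot_src` to `kernels.dot` and running `dot -Tsvg kernels.dot -o kernels.svg` gives a picture in which a buffer read by several kernels appears as a node with several outgoing edges.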
> Would you know how I can do that?
You could look at the various `*thunk` files: inside the `Run` method each thunk knows the buffers it is acting on, so you could either reuse the existing `VLOG` lines or add your own.
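Here is a minimal sketch of enabling such `VLOG` lines from JAX. It assumes the TSL logging environment variables `TF_CPP_MIN_LOG_LEVEL` and `TF_CPP_VMODULE` are honoured, and they must be set before jaxlib's native library loads. The basenames `kernel_thunk` and `gpu_executable` are illustrative guesses, not a verified list. Check which `*thunk` source files exist in your XLA version and which `VLOG` levels they actually use. On a CPU-only install this simply produces no extra GPU logging.

```python
import os

# Must be set before `import jax`, because logging is configured when the native library loads.
# TF_CPP_MIN_LOG_LEVEL=0 lets INFO-severity messages (where VLOG output goes) through.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
# Comma-separated "<source file basename>=<verbosity>" pairs.
# These basenames are guesses; substitute the *thunk files you actually want to trace.
os.environ["TF_CPP_VMODULE"] = "kernel_thunk=3,gpu_executable=2"

import jax
import jax.numpy as jnp

x = jnp.ones((128, 256), dtype=jnp.float32)
result = jax.jit(lambda a: (a.T @ a).sum())(x).block_until_ready()
```

If the existing log lines don't include buffer addresses or shapes, adding a `VLOG` in the relevant thunk's `Run` method and rebuilding jaxlib is the remaining option, as suggested above.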
> I guess I'm trying to see the memory read/write graph at the CUDA kernel boundaries. Is there some existing way to see this?
Yes, that would be the buffer assignment. You'll only see offsets there (which you can match with instruction names in the after-optimizations HLO), because the actual memory is allocated at runtime.
fwiw, you can see how to parse the buffer assignment format here: https://github.com/openxla/xla/blob/main/xla/tools/driver.cc#L311
(This is running on an NVIDIA RTX 4090 GPU with JAX 0.4.31.)
I have code that is something like the example below. Here, the depth-wise convolution needs its input transposed from [batch, sequence, feature] to [batch, feature, sequence] so that it can apply the convolution along the sequence axis.
The output of the convolution is used 3 times, and XLA generates at least 3 separate (fused) transposes, each of which does a full read and write of memory. This is very slow and causes sadness.
Unfortunately, this example code doesn't reproduce the problem: the problem seems to be quite sensitive to the surrounding code, and trimming it down makes most of the issue go away. A screen-grab from the profiled code somewhat shows the issue:
After the convolution there is a `loop-transpose_fusion`; after the two CUTLASS GEMM kernels there are two `input_transpose_fusion` kernels; and following `kernel__1` there is another `input_transpose_fusion`. Each of these fusions does a full read/write of memory.

My main question is: how can I effectively debug this? For example, is there a way to log all the GPU kernel calls along with their argument shapes? Is there some way to see why the transposes weren't fused into a single kernel with one input and three outputs?
module_0071.jit_apply.sm_8.9_gpu_after_optimizations.txt
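The author's example code is not included in this copy of the issue, and by their account it doesn't reproduce the slowdown anyway. The sketch below is a hypothetical stand-in for the general shape of the pattern: a depth-wise convolution over the sequence axis written with explicit transposes, whose output has three consumers. Sizes, weights and the GELU consumer are all made up. It shows how to check the optimized HLO for transposes from inside JAX with `jit(...).lower(...).compile().as_text()`, which is quicker to iterate on than a profile.

```python
import jax
import jax.numpy as jnp


def depthwise_conv_over_seq(x, w):
    """Hypothetical depth-wise 1D conv along the sequence axis.

    x: [batch, seq, feat]; w: [feat, 1, kernel], i.e. one filter per feature.
    """
    xt = jnp.transpose(x, (0, 2, 1))  # [batch, feat, seq]
    y = jax.lax.conv_general_dilated(
        xt, w,
        window_strides=(1,),
        padding="SAME",
        dimension_numbers=("NCH", "OIH", "NCH"),
        feature_group_count=x.shape[-1],  # depth-wise: one group per feature
    )
    return jnp.transpose(y, (0, 2, 1))  # back to [batch, seq, feat]


@jax.jit
def block(x, w, a, b):
    y = depthwise_conv_over_seq(x, w)
    # The conv output has three consumers, as described in the issue.
    return y @ a, y @ b, jax.nn.gelu(y)


x = jnp.ones((2, 16, 8))
w = jnp.ones((8, 1, 3))
a = jnp.ones((8, 4))
b = jnp.ones((8, 4))
outs = block(x, w, a, b)

# Inspect the optimized HLO rather than a profile.
hlo = block.lower(x, w, a, b).compile().as_text()
n_transposes = hlo.count("transpose")
print(f"'transpose' appears {n_transposes} times in the optimized HLO")
```

Since the real problem only shows up in the full model, the same `lower(...).compile().as_text()` check is most useful when applied to the real function. It may also be worth comparing against a channels-last formulation (`dimension_numbers=("NHC", "HIO", "NHC")` with a `[kernel, 1, feat]` filter), which needs no explicit transposes in the JAX code.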