patrick-kidger opened 10 months ago
The dot_general followed by transpose seems like something the compiler should be able to fuse
> Which seems to indicate that we might be better off avoiding composing JAX transforms, unfortunately.
I think this is too strong a conclusion – I think you could pretty easily generate the same jaxpr without vmap
and run into the same issue, and you could pretty easily use nested vmap
to generate a jaxpr that the compiler handles more efficiently.
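For what it's worth, here is a rough sketch (not from the original report; shapes taken from the HLO dumps below) of how essentially the same dot_general-plus-transpose jaxpr can be written by hand, with no vmap involved:

```python
import jax.numpy as jnp
from jax import lax

w = jnp.zeros((600, 821))
a = jnp.zeros((32, 100, 821))

def by_hand(w, a):
    # Contract w's dim 1 against a's trailing dim, with no batch dimensions
    # (the same dimension_numbers the double-vmap version lowers to).
    out = lax.dot_general(w, a, dimension_numbers=(((1,), (2,)), ((), ())))
    # `out` has shape (600, 32, 100); move the batch dims to the front,
    # mirroring the transpose[permutation=(1, 2, 0)] seen in the HLO metadata.
    return jnp.transpose(out, (1, 2, 0))
```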
I tried this on GPU, and found that the two programs have similar runtimes and nearly identical compiled HLO. So this is more than likely an XLA:CPU-only issue.
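For reference, the snippets below assume f1 and f2 are roughly of the following shape (a sketch consistent with the entry computation layout in the HLO, not the verbatim MWE from the linked issue):

```python
import jax
import jax.numpy as jnp

w = jnp.zeros((600, 821))
a = jnp.zeros((32, 100, 821))

def matvec(w, x):
    return w @ x

# Double vmap over the two leading batch dimensions of `a`.
f1 = jax.vmap(jax.vmap(matvec, in_axes=(None, 0)), in_axes=(None, 0))

# einsum equivalent, matching the "a b, ... b -> ... a" op name in the HLO metadata.
def f2(w, x):
    return jnp.einsum("a b, ... b -> ... a", w, x)
```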
```python
print(jax.jit(f1).lower(w, a).compile().as_text())
print(jax.jit(f2).lower(w, a).compile().as_text())
```
HloModule jit_f1, is_scheduled=true, entry_computation_layout={(f32[600,821]{1,0}, f32[32,100,821]{2,1,0})->f32[32,100,600]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true}
ENTRY %main.9 (Arg_0.1: f32[600,821], Arg_1.2: f32[32,100,821]) -> f32[32,100,600] {
%Arg_1.2 = f32[32,100,821]{2,1,0} parameter(1), sharding={replicated}
%bitcast.11 = f32[821,3200]{0,1} bitcast(f32[32,100,821]{2,1,0} %Arg_1.2)
%Arg_0.1 = f32[600,821]{1,0} parameter(0), sharding={replicated}
%custom-call.1 = (f32[600,3200]{0,1}, s8[4194304]{0}) custom-call(f32[600,821]{1,0} %Arg_0.1, f32[821,3200]{0,1} %bitcast.11), custom_call_target="__cublas$gemm", metadata={op_name="jit(f1)/jit(main)/jit(f1)/dot_general[dimension_numbers=(((1,), (2,)), ((), ())) precision=None preferred_element_type=float32]" source_file="<ipython-input-1-ff468e097a57>" source_line=11}, backend_config={"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT","selected_algorithm":"10","lhs_stride":"492600","rhs_stride":"2627200","grad_x":false,"grad_y":false}
%get-tuple-element = f32[600,3200]{0,1} get-tuple-element((f32[600,3200]{0,1}, s8[4194304]{0}) %custom-call.1), index=0, metadata={op_name="jit(f1)/jit(main)/jit(f1)/dot_general[dimension_numbers=(((1,), (2,)), ((), ())) precision=None preferred_element_type=float32]" source_file="<ipython-input-1-ff468e097a57>" source_line=11}
ROOT %bitcast.3 = f32[32,100,600]{2,1,0} bitcast(f32[600,3200]{0,1} %get-tuple-element), frontend_attributes={fingerprint_before_lhs="82afe7c615af2843ef76993f3c5f0680"}, metadata={op_name="jit(f1)/jit(main)/jit(f1)/transpose[permutation=(1, 2, 0)]" source_file="<ipython-input-1-ff468e097a57>" source_line=23}
}
HloModule jit_f2, is_scheduled=true, entry_computation_layout={(f32[600,821]{1,0}, f32[32,100,821]{2,1,0})->f32[32,100,600]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true}
ENTRY %main.8 (Arg_0.1: f32[600,821], Arg_1.2: f32[32,100,821]) -> f32[32,100,600] {
%Arg_1.2 = f32[32,100,821]{2,1,0} parameter(1), sharding={replicated}
%bitcast.5 = f32[3200,821]{1,0} bitcast(f32[32,100,821]{2,1,0} %Arg_1.2)
%Arg_0.1 = f32[600,821]{1,0} parameter(0), sharding={replicated}
%custom-call.1 = (f32[3200,600]{1,0}, s8[4194304]{0}) custom-call(f32[3200,821]{1,0} %bitcast.5, f32[600,821]{1,0} %Arg_0.1), custom_call_target="__cublas$gemm", metadata={op_name="jit(f2)/jit(main)/jit(f2)/a b, ... b -> ... a/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=float32]" source_file="<ipython-input-1-ff468e097a57>" source_line=15}, backend_config={"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT","selected_algorithm":"10","lhs_stride":"2627200","rhs_stride":"492600","grad_x":false,"grad_y":false}
%get-tuple-element = f32[3200,600]{1,0} get-tuple-element((f32[3200,600]{1,0}, s8[4194304]{0}) %custom-call.1), index=0, metadata={op_name="jit(f2)/jit(main)/jit(f2)/a b, ... b -> ... a/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=float32]" source_file="<ipython-input-1-ff468e097a57>" source_line=15}
ROOT %bitcast.1 = f32[32,100,600]{2,1,0} bitcast(f32[3200,600]{1,0} %get-tuple-element), frontend_attributes={fingerprint_before_lhs="e457c5c14342da4d62d047d3aeb1d0c5"}, metadata={op_name="jit(f2)/jit(main)/jit(f2)/a b, ... b -> ... a/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=float32]" source_file="<ipython-input-1-ff468e097a57>" source_line=15}
}
> The dot_general followed by transpose seems like something the compiler should be able to fuse
Agreed!
> I think this is too strong a conclusion
Hmm, my thinking here was that this looks like a case where (a) einsum
was doing its thing to try and act in an optimal way, whilst (b) vmap
was transposing dimensions for the sake of easier codegen (a common pattern in batching rules). And so the slow-down really was due to the use of a transform, morally speaking.
(I do agree you could find a way to generate each version either way if you wanted, though.)
I don't think there's any meaningful attempt at that kind of logic in the einsum implementation; I think you just got unlucky and hit a codepath that is poorly optimized in XLA:CPU.
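(As an aside: with the f1/f2 sketch from earlier in the thread, the jaxprs can be compared directly. Per the original report, the difference is the operand order passed to dot_general, with the vmap version carrying an extra transpose inserted by the batching rule:)

```python
# f1, f2, w, a as in the sketch earlier in the thread (hypothetical reconstruction).
print(jax.make_jaxpr(f1)(w, a))  # dot_general(w, a, ...) followed by a transpose
print(jax.make_jaxpr(f2)(w, a))  # a single dot_general with the operands in the other order
```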
Description
Consider the following two ways of computing a batch-batch-matvec. The approach that uses a double vmap is actually twice as slow as using an einsum. (Which seems to indicate that we might be better off avoiding composing JAX transforms, unfortunately.)

Printing out the jaxprs from each computation, it seems that the difference is the order of arguments to dot_general.

Originally reported in https://github.com/patrick-kidger/equinox/issues/636, which I've reduced to an MWE here.
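(Not the verbatim MWE from the linked issue, but a rough sketch of how the timing comparison can be reproduced, reusing the f1/f2 definitions sketched earlier in this thread:)

```python
import timeit

import jax

# f1, f2, w, a as in the sketch earlier in the thread (hypothetical reconstruction).
f1_jit = jax.jit(f1)
f2_jit = jax.jit(f2)
f1_jit(w, a).block_until_ready()  # warm up / compile before timing
f2_jit(w, a).block_until_ready()

print("double-vmap:", timeit.timeit(lambda: f1_jit(w, a).block_until_ready(), number=100))
print("einsum:", timeit.timeit(lambda: f2_jit(w, a).block_until_ready(), number=100))
```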
What jax/jaxlib version are you using?
0.4.23
Which accelerator(s) are you using?
CPU
Additional system info?
Python 3.11.3, Linux
NVIDIA GPU info
No response