jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Order of arguments to `dot_general` significantly affects performance #19327

Open patrick-kidger opened 10 months ago

patrick-kidger commented 10 months ago

Description

Consider the following two ways of computing a batch-batch-matvec. The approach that uses a double-vmap is actually twice as slow as using an einsum. (Which seems to indicate that we might be better off avoiding composing JAX transforms, unfortunately.)

import functools as ft
import timeit

import jax
import jax.numpy as jnp

@jax.jit
@ft.partial(jax.vmap, in_axes=(None, 0))
@ft.partial(jax.vmap, in_axes=(None, 0))
def f1(w, a):
    return w @ a

@jax.jit
def f2(w, a):
    return jnp.einsum('a b, ... b -> ... a', w, a)

w = jnp.ones((600, 821))
a = jnp.ones((32, 100, 821))

print(min(timeit.repeat(lambda: f1(w, a), number=1, repeat=100)))
print(min(timeit.repeat(lambda: f2(w, a), number=1, repeat=100)))

# 0.020185229001072003
# 0.010345280999899842
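For anyone re-running the timings, here is a minimal sketch of a slightly more careful benchmark. It assumes the same `f1`, `f2`, `w` and `a` as above; `bench` is a hypothetical helper, not part of JAX. It runs one warm-up call so compilation is not timed, and calls `block_until_ready()` so that asynchronous dispatch cannot cut a measurement short. Absolute numbers will depend on the machine and the jax/jaxlib version.

```python
import functools as ft
import timeit

import jax
import jax.numpy as jnp


@jax.jit
@ft.partial(jax.vmap, in_axes=(None, 0))
@ft.partial(jax.vmap, in_axes=(None, 0))
def f1(w, a):
    return w @ a


@jax.jit
def f2(w, a):
    return jnp.einsum('a b, ... b -> ... a', w, a)


def bench(fn, *args, repeat=20):
    """Hypothetical helper: best-of-`repeat` wall time in seconds."""
    # Warm-up call so that tracing and compilation are not timed.
    fn(*args).block_until_ready()
    # Block on the result so each measurement covers the full computation.
    timings = timeit.repeat(
        lambda: fn(*args).block_until_ready(), number=1, repeat=repeat
    )
    return min(timings)


w = jnp.ones((600, 821))
a = jnp.ones((32, 100, 821))

t1 = bench(f1, w, a)
t2 = bench(f2, w, a)
print(f"double vmap: {t1:.6f} s")
print(f"einsum:      {t2:.6f} s")
```

Taking the minimum over repeats matches the original measurement; the warm-up and the explicit blocking are the only changes.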

Printing out the jaxprs from each computation, it seems that the difference is the order of arguments to dot_general:

print(jax.make_jaxpr(f1)(w, a))
print(jax.make_jaxpr(f2)(w, a))

# { lambda ; a:f32[600,821] b:f32[32,100,821]. let
#     c:f32[32,100,600] = pjit[
#       name=f1
#       jaxpr={ lambda ; d:f32[600,821] e:f32[32,100,821]. let
#           f:f32[600,32,100] = dot_general[
#             dimension_numbers=(([1], [2]), ([], []))
#             preferred_element_type=float32
#           ] d e
#           g:f32[32,100,600] = transpose[permutation=(1, 2, 0)] f
#         in (g,) }
#     ] a b
#   in (c,) }
# { lambda ; a:f32[600,821] b:f32[32,100,821]. let
#     c:f32[32,100,600] = pjit[
#       name=f2
#       jaxpr={ lambda ; d:f32[600,821] e:f32[32,100,821]. let
#           f:f32[32,100,600] = dot_general[
#             dimension_numbers=(([2], [1]), ([], []))
#             preferred_element_type=float32
#           ] e d
#         in (f,) }
#     ] a b
#   in (c,) }

Originally reported in https://github.com/patrick-kidger/equinox/issues/636, which I've reduced to a MWE here.

What jax/jaxlib version are you using?

0.4.23

Which accelerator(s) are you using?

CPU

Additional system info?

Python 3.11.3, Linux

NVIDIA GPU info

No response

jakevdp commented 10 months ago

The dot_general followed by transpose seems like something the compiler should be able to fuse

> Which seems to indicate that we might be better off avoiding composing JAX transforms, unfortunately.

I think this is too strong a conclusion: you could pretty easily generate the same jaxpr without vmap and run into the same issue, and you could just as easily use nested vmap to generate a jaxpr that the compiler handles more efficiently.
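As a rough illustration of that point, the sketch below builds both jaxprs by hand with `jax.lax.dot_general`, with no `vmap` involved, and adds a nested-`vmap` variant that writes the per-example product as `a @ w.T`. The names `g1`, `g2` and `g3` are made up for this example, and random inputs are used only so the equality check is not trivial. What jaxpr `g3` traces to is not asserted here and may vary between JAX versions; the printed jaxprs are the thing to inspect.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def g1(w, a):
    # Same primitives as the f1 jaxpr: weight operand first, then a transpose.
    out = lax.dot_general(w, a, dimension_numbers=(((1,), (2,)), ((), ())))
    return jnp.transpose(out, (1, 2, 0))


@jax.jit
def g2(w, a):
    # Same primitive as the f2 jaxpr: batched operand first, no transpose.
    return lax.dot_general(a, w, dimension_numbers=(((2,), (1,)), ((), ())))


@jax.jit
def g3(w, a):
    # Nested vmap again, but with the per-example product written as a @ w.T.
    matvec = lambda w, a: a @ w.T
    return jax.vmap(jax.vmap(matvec, in_axes=(None, 0)), in_axes=(None, 0))(w, a)


key_w, key_a = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (600, 821))
a = jax.random.normal(key_a, (32, 100, 821))

# Print each jaxpr to see which operand order and extra ops it contains.
for g in (g1, g2, g3):
    print(jax.make_jaxpr(g)(w, a))
```

All three compute the same `f32[32,100,600]` result up to floating-point rounding, so any difference in runtime comes from how the backend handles each form rather than from `vmap` as such.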

jakevdp commented 10 months ago

I tried this on GPU, and found that the two programs have similar runtimes and nearly identical compiled HLO. So this is more than likely an XLA:CPU-only issue.

print(jax.jit(f1).lower(w, a).compile().as_text())
print(jax.jit(f2).lower(w, a).compile().as_text())

HloModule jit_f1, is_scheduled=true, entry_computation_layout={(f32[600,821]{1,0}, f32[32,100,821]{2,1,0})->f32[32,100,600]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true}

ENTRY %main.9 (Arg_0.1: f32[600,821], Arg_1.2: f32[32,100,821]) -> f32[32,100,600] {
  %Arg_1.2 = f32[32,100,821]{2,1,0} parameter(1), sharding={replicated}
  %bitcast.11 = f32[821,3200]{0,1} bitcast(f32[32,100,821]{2,1,0} %Arg_1.2)
  %Arg_0.1 = f32[600,821]{1,0} parameter(0), sharding={replicated}
  %custom-call.1 = (f32[600,3200]{0,1}, s8[4194304]{0}) custom-call(f32[600,821]{1,0} %Arg_0.1, f32[821,3200]{0,1} %bitcast.11), custom_call_target="__cublas$gemm", metadata={op_name="jit(f1)/jit(main)/jit(f1)/dot_general[dimension_numbers=(((1,), (2,)), ((), ())) precision=None preferred_element_type=float32]" source_file="<ipython-input-1-ff468e097a57>" source_line=11}, backend_config={"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT","selected_algorithm":"10","lhs_stride":"492600","rhs_stride":"2627200","grad_x":false,"grad_y":false}
  %get-tuple-element = f32[600,3200]{0,1} get-tuple-element((f32[600,3200]{0,1}, s8[4194304]{0}) %custom-call.1), index=0, metadata={op_name="jit(f1)/jit(main)/jit(f1)/dot_general[dimension_numbers=(((1,), (2,)), ((), ())) precision=None preferred_element_type=float32]" source_file="<ipython-input-1-ff468e097a57>" source_line=11}
  ROOT %bitcast.3 = f32[32,100,600]{2,1,0} bitcast(f32[600,3200]{0,1} %get-tuple-element), frontend_attributes={fingerprint_before_lhs="82afe7c615af2843ef76993f3c5f0680"}, metadata={op_name="jit(f1)/jit(main)/jit(f1)/transpose[permutation=(1, 2, 0)]" source_file="<ipython-input-1-ff468e097a57>" source_line=23}
}

HloModule jit_f2, is_scheduled=true, entry_computation_layout={(f32[600,821]{1,0}, f32[32,100,821]{2,1,0})->f32[32,100,600]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true}

ENTRY %main.8 (Arg_0.1: f32[600,821], Arg_1.2: f32[32,100,821]) -> f32[32,100,600] {
  %Arg_1.2 = f32[32,100,821]{2,1,0} parameter(1), sharding={replicated}
  %bitcast.5 = f32[3200,821]{1,0} bitcast(f32[32,100,821]{2,1,0} %Arg_1.2)
  %Arg_0.1 = f32[600,821]{1,0} parameter(0), sharding={replicated}
  %custom-call.1 = (f32[3200,600]{1,0}, s8[4194304]{0}) custom-call(f32[3200,821]{1,0} %bitcast.5, f32[600,821]{1,0} %Arg_0.1), custom_call_target="__cublas$gemm", metadata={op_name="jit(f2)/jit(main)/jit(f2)/a b, ... b -> ... a/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=float32]" source_file="<ipython-input-1-ff468e097a57>" source_line=15}, backend_config={"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT","selected_algorithm":"10","lhs_stride":"2627200","rhs_stride":"492600","grad_x":false,"grad_y":false}
  %get-tuple-element = f32[3200,600]{1,0} get-tuple-element((f32[3200,600]{1,0}, s8[4194304]{0}) %custom-call.1), index=0, metadata={op_name="jit(f2)/jit(main)/jit(f2)/a b, ... b -> ... a/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=float32]" source_file="<ipython-input-1-ff468e097a57>" source_line=15}
  ROOT %bitcast.1 = f32[32,100,600]{2,1,0} bitcast(f32[3200,600]{1,0} %get-tuple-element), frontend_attributes={fingerprint_before_lhs="e457c5c14342da4d62d047d3aeb1d0c5"}, metadata={op_name="jit(f2)/jit(main)/jit(f2)/a b, ... b -> ... a/dot_general[dimension_numbers=(((2,), (1,)), ((), ())) precision=None preferred_element_type=float32]" source_file="<ipython-input-1-ff468e097a57>" source_line=15}
}
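
To run the same comparison on another backend (CPU, for instance), something like the following sketch may help. It collects both the lowered module text (before XLA optimization) and the compiled HLO text for each function. `hlo_texts` is a hypothetical helper, and counting the substring `" transpose("` is only a crude heuristic for spotting a surviving transpose op; reading the full compiled text is more reliable.

```python
import functools as ft

import jax
import jax.numpy as jnp


@jax.jit
@ft.partial(jax.vmap, in_axes=(None, 0))
@ft.partial(jax.vmap, in_axes=(None, 0))
def f1(w, a):
    return w @ a


@jax.jit
def f2(w, a):
    return jnp.einsum('a b, ... b -> ... a', w, a)


def hlo_texts(fn, *args):
    """Hypothetical helper: (text before optimization, compiled HLO text)."""
    lowered = jax.jit(fn).lower(*args)
    return lowered.as_text(), lowered.compile().as_text()


w = jnp.ones((600, 821))
a = jnp.ones((32, 100, 821))

texts = {}
for name, fn in [("f1", f1), ("f2", f2)]:
    before, after = hlo_texts(fn, w, a)
    texts[name] = (before, after)
    # Crude heuristic: count transpose ops left in the optimized module.
    print(name, "transpose ops in compiled HLO:", after.count(" transpose("))
```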

patrick-kidger commented 10 months ago

> The dot_general followed by transpose seems like something the compiler should be able to fuse

Agreed!

> I think this is too strong a conclusion

Hmm, my thinking here was that this looks like a case where (a) einsum was doing its thing to try and act in an optimal way, whilst (b) vmap was transposing dimensions for the sake of easier codegen (a common pattern in batching rules). And so the slow-down really was due to the use of a transform, morally speaking.

(I do agree you could find a way to generate each version either way if you wanted, though.)

jakevdp commented 10 months ago

I don't think there's any meaningful attempt at that kind of logic in the einsum implementation; I think you just got unlucky and hit a codepath that is poorly optimized in XLA:CPU.