google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
29.06k stars, 2.66k forks

Extremely slow GPU execution #7024

Open NeilGirdhar opened 3 years ago

NeilGirdhar commented 3 years ago

The following code is almost instantaneous (<1ms) on the CPU, but is extremely slow on the GPU (7s). I'm trying to track down the source of the problem. I have pared down my code from 5000 lines to 80 lines, and I don't think I can remove any more. I have added comments in the places where I found surprising (to me) effects on the GPU run time.

How can I make this code run faster on the GPU than it does on the CPU? What am I doing wrong?

from functools import partial
from typing import Any
import haiku as hk
import jax.numpy as jnp
from contexttimer import Timer
from jax import jit
from jax.experimental import enable_x64
from jax.lax import while_loop
from jax.nn import sigmoid, softplus
from jax.random import PRNGKey, normal, split
from tjax.dataclasses import dataclass  # Equivalent to flax.struct.dataclass

class Linear(hk.Module):
    def __init__(self, output_size: int):
        super().__init__()
        self.output_size = output_size

    def __call__(self, inputs):
        w = hk.get_parameter("w", [inputs.shape[-1], self.output_size],
                             inputs.dtype,  # Passing dtype costs 23%!
                             init=jnp.zeros)
        # Calling softplus costs 32%!
        return jnp.dot(inputs, softplus(w))

class NoisyMLP(hk.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        self.layers = [Linear(output_size) for output_size in layer_sizes]

    def __call__(self, inputs):
        out = inputs
        for layer in self.layers:
            out = layer(out)
            out = sigmoid(out)  # Sigmoid costs 10%!
        return out

@dataclass
class SamplerState:
    code_momentum: Any
    rng: Any
    iterations: Any

shape = (1,)

def nat_to_exp(natural_explanation):
    mlp = NoisyMLP((12, *shape))
    return mlp(natural_explanation)

def haiku_weight_initializer() -> None:
    nat_to_exp(jnp.zeros(shape))

def state_needs_iteration(maximum_iterations, state) -> bool:
    return state.iterations < maximum_iterations

def update_state(weights, state):
    leak_rng, new_rng = split(state.rng)
    nat_to_exp_f = hk.transform(nat_to_exp).apply
    force = nat_to_exp_f(weights, None, state.code_momentum)
    new_code_momentum = force + normal(leak_rng, force.shape)
    return SamplerState(new_code_momentum, new_rng, state.iterations + 1)

def find_fixed_point(weights, initial_state, maximum_iterations):
    return while_loop(partial(state_needs_iteration, maximum_iterations),
                      partial(update_state, weights),
                      initial_state)

@partial(jit, static_argnums=())  # Passing maximum_iterations non-statically costs 43%!
def infer_encoding(weights, initial_rng, maximum_iterations):
    initial_sampler_state = SamplerState(jnp.zeros(shape), initial_rng, 0)
    return find_fixed_point(weights, initial_sampler_state, maximum_iterations)

with enable_x64():  # Enabling 64-bit costs 50%.
    rng = PRNGKey(12)
    weight_rng, inference_rng = split(rng)
    weights = hk.transform(haiku_weight_initializer).init(weight_rng)
    for _ in range(10):
        with Timer() as timer:
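            # JAX dispatches asynchronously; block so the timer measures the computation.
            infer_encoding(weights, inference_rng, 8000).code_momentum.block_until_ready()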
        print(timer.elapsed)
mattjj commented 3 years ago

Thanks for raising this, and for working so hard to minimize it!

The best tool here is to use profiling. If you can get a profile showing a realistic workload, we can really dig into what improvements can be made (either to your code, to JAX itself, or to XLA:GPU).

There's one effect that would explain one of your comments, though I don't think it would explain the code as written being slow. General while_loops can require returning control to the host on each iteration just to decide whether to dispatch another iteration of the loop body on the GPU, incurring expensive synchronization and transfer overheads (which would loom large when the loop body itself is cheap). But in XLA:GPU there's a "for loop optimization" which is meant to notice when the loop actually has a statically fixed trip count (as it does here, at least with the code as written!) so that control need not be returned to the host on each iteration.
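The static-versus-dynamic trip count distinction is already visible at the JAX level. `jax.lax.fori_loop` is documented to lower to `scan` when its bounds are known at trace time (for example, Python integers) and to `while_loop` otherwise. The sketch below prints the jaxpr for both cases; the loop body is a hypothetical stand-in for the MLP update. It shows only what JAX hands to XLA, not XLA:GPU's own for-loop optimization.

```python
import jax
import jax.numpy as jnp
from jax import lax


def body(i, x):
    # Hypothetical cheap loop body standing in for the MLP update.
    return jnp.sin(x) + 1.0


def static_trip_count(x):
    # Python-int bounds: the trip count is known at trace time.
    return lax.fori_loop(0, 8000, body, x)


def dynamic_trip_count(x, n):
    # `n` is traced, so the trip count is only known at run time.
    return lax.fori_loop(0, n, body, x)


static_jaxpr = str(jax.make_jaxpr(static_trip_count)(1.0))
dynamic_jaxpr = str(jax.make_jaxpr(dynamic_trip_count)(1.0, 8000))
print("scan" in static_jaxpr)    # fixed trip count -> scan primitive
print("while" in dynamic_jaxpr)  # traced trip count -> while primitive
```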

Could you share a profile of the execution so we can dig in?

NeilGirdhar commented 3 years ago

Thank you so much for the lightning fast reply!

The first thing I tried was profiling, but I just couldn't make sense of all the information. I even tried using jax.profiler.TraceAnnotation to improve the output, but it was still opaque to me. Here's the capture from my actual program: tfevents.tar.gz

General while_loops can require returning control to the host on each iteration just to decide whether to dispatch another iteration of the loop body on the GPU, incurring expensive synchronization and transfer overheads

Thanks for explaining that! I can definitely convert these loops to something else. I guess I should convert these to jax.lax.scan? I did that, and the time fell by 47%. I set unroll to 2 to get the time down another 20% from there. (It is still hundreds of times slower than the CPU version, though.)

I wonder if your comment here is related?

mattjj commented 3 years ago

I guess I should convert these to jax.lax.scan?

Well, I wouldn't suggest using scan, because I'm pretty sure XLA will do that optimization for you when you make the trip count static (as you did in the example). So there's nothing additional to be gained by using scan I don't think.

Yes that comment is related.

        # Calling softplus costs 32%!
        return jnp.dot(inputs, softplus(w))

For this part, because XLA:GPU is likely calling into pre-packaged CUDA kernels for the dot (unless it knows it can generate better code than the closed-source Nvidia-provided kernels, which is rare), adding the softplus may mean that you have to launch two kernels (one for the dot, one for the presumably fused xla-generated softplus computation) per call to the MLP (i.e. at least two per loop iteration), rather than just one (just the dot).
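One way to see how a loop body is split up is to look at the optimized HLO that XLA compiles. The sketch below uses hypothetical shapes that loosely follow the `Linear` layer above and prints the compiled module for whatever backend is active. On GPU, each top-level `fusion` and each `custom-call` into cuBLAS roughly corresponds to one kernel launch. On CPU the output will look different.

```python
import jax
import jax.numpy as jnp
from jax.nn import softplus


def linear(inputs, w):
    # Same computation as Linear.__call__ above: softplus on the weights, then a dot.
    return jnp.dot(inputs, softplus(w))


inputs = jnp.ones((1, 12))
w = jnp.zeros((12, 12))

# Lower and compile for the default backend, then dump the optimized HLO.
compiled = jax.jit(linear).lower(inputs, w).compile()
hlo_text = compiled.as_text()
print(hlo_text)
```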

(Tangentially, XLA:TPU has much more flexibility here: since it generates the dot routine too, it can fuse things like elementwise operations into the loads and stores of the dot operation, and indeed on TPU any jitted function leads to one big optimized XLA:TPU program, rather than separate kernels as on XLA:GPU.)

By the way, instead of splitting the RNG on every iteration, you could split it once into a big array (with leading axis size maximum_iterations) and, say, scan over it (or just index into it with the iteration counter). That can also save kernel launches in the loop body, though it'll mean your program uses more memory.
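Here is a minimal sketch of the pre-split approach. It assumes a hypothetical update, `jnp.tanh(x) + noise`, in place of the MLP. The key is split once into `NUM_ITERATIONS` keys outside the loop, and `lax.scan` feeds one key to each iteration, so `split` never runs inside the loop body.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.random import PRNGKey, normal, split

NUM_ITERATIONS = 8000  # plays the role of maximum_iterations


def body(x, key):
    # `key` is one row of the pre-split key array; no split() in the loop.
    return jnp.tanh(x) + normal(key, x.shape), None


@jax.jit
def run(key, x0):
    keys = split(key, NUM_ITERATIONS)  # one key per iteration, made up front
    x_final, _ = lax.scan(body, x0, keys)
    return x_final


print(run(PRNGKey(0), jnp.zeros((1,))))
```

A variant of the same idea is to draw all of the noise with a single `normal` call before the scan and scan over the noise array instead. It trades even more memory for even less work per iteration.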

The overall theme here is to try to minimize the number of kernel launches per loop iteration.

I haven't looked at your profile yet, but I'll try to get the chance soon!

mattjj commented 3 years ago

About the profile: would it be possible to share a screenshot of the TensorBoard visualization? A screenshot is easy to act on, and to show to others!

NeilGirdhar commented 3 years ago

Well, I wouldn't suggest using scan, because I'm pretty sure XLA will do that optimization for you when you make the trip count static (as you did in the example). So there's nothing additional to be gained by using scan I don't think.

Okay that explains why the speedup matches just making it static. One benefit to scan is that it's an assertion that the iteration limit is static. There's no jax.is_static, right? (I guess you can try: int(x); except TypeError:)
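The heuristic can be sketched as a hypothetical helper (there is no `is_static` in JAX). Converting a traced value to a Python `int` raises a subclass of `TypeError`, so the helper reports whether a scalar is concrete at trace time. It is only meant for scalars: converting a multi-element array also raises `TypeError` and would be misreported as non-static.

```python
import jax


def is_static(x) -> bool:
    """Hypothetical helper: True if scalar `x` is concrete at trace time."""
    try:
        int(x)
    except TypeError:
        # Tracers raise a TypeError subclass (e.g. ConcretizationTypeError).
        return False
    return True


seen = []


def record(n):
    # Runs once per trace, so `seen` records what jit passed in.
    seen.append(is_static(n))
    return n


jax.jit(record)(3)                    # n is traced
jax.jit(record, static_argnums=0)(3)  # n is a Python int
print(seen)
```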

By the way, if instead of splitting the RNG on every iteration, you just split it once into a big array

That's a great idea. I'll try that.

About the profile: would it be possible to share a screenshot of the TensorBoard visualization? A screenshot is easy to act on, and to show to others!

Sure, which tab do you want me to screenshot?

NeilGirdhar commented 3 years ago

I made your two suggested changes (below), and the runtime went from 0.69s down to 0.13s:

It's still much slower than running this on the CPU, and I don't understand why.

In my real code, I can unroll the while loop over 8000 iterations by replacing it with a while loop over 8 iterations of scans of 100 iterations. I can generate the random numbers outside the loop even though this means restructuring my code significantly. (It also means that Haiku's way of handling random number generation is inappropriate for use in loops.)

from functools import partial
from typing import Any
import haiku as hk
import jax.numpy as jnp
from contexttimer import Timer
from jax import jit
from jax.experimental import enable_x64
from jax.lax import scan
from jax.nn import sigmoid, softplus
from jax.random import PRNGKey, normal, split
from tjax.dataclasses import dataclass

class Linear(hk.Module):
    def __init__(self, output_size: int):
        super().__init__()
        self.output_size = output_size

    def __call__(self, inputs):
        w = hk.get_parameter("w", [inputs.shape[-1], self.output_size],
                             inputs.dtype,  # Passing dtype costs 23%!
                             init=jnp.zeros)
        # Calling softplus costs 32%!
        return jnp.dot(inputs, softplus(w))

class NoisyMLP(hk.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        self.layers = [Linear(output_size) for output_size in layer_sizes]

    def __call__(self, inputs):
        out = inputs
        for layer in self.layers:
            out = layer(out)
            out = sigmoid(out)  # Sigmoid costs 10%!
        return out

@dataclass
class SamplerState:
    code_momentum: Any
    rng: Any
    iterations: Any

shape = (1,)

def nat_to_exp(natural_explanation):
    mlp = NoisyMLP((12, *shape))
    return mlp(natural_explanation)

def haiku_weight_initializer() -> None:
    nat_to_exp(jnp.zeros(shape))

def update_state(weights, state, diffusion):
    nat_to_exp_f = hk.transform(nat_to_exp).apply
    force = nat_to_exp_f(weights, None, state.code_momentum)
    new_code_momentum = force + diffusion
    return SamplerState(new_code_momentum, state.rng, state.iterations + 1)

def find_fixed_point(weights, initial_state, maximum_iterations):
    def f(state, diffusion):
        return update_state(weights, state, diffusion), None
    leak_rng, new_rng = split(initial_state.rng)
    diffusion = normal(leak_rng, (maximum_iterations,) + initial_state.code_momentum.shape)
    retval = scan(f, initial_state, diffusion, length=maximum_iterations, unroll=2)[0]
    retval = retval.replace(rng=new_rng)
    return retval

@partial(jit, static_argnums=(2,))
def infer_encoding(weights, initial_rng, maximum_iterations):
    initial_sampler_state = SamplerState(jnp.zeros(shape), initial_rng, 0)
    return find_fixed_point(weights, initial_sampler_state, maximum_iterations)

with enable_x64():  # Enabling 64-bit costs 50%.
    rng = PRNGKey(12)
    weight_rng, inference_rng = split(rng)
    weights = hk.transform(haiku_weight_initializer).init(weight_rng)
    for _ in range(10):
        with Timer() as timer:
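            # JAX dispatches asynchronously; block so the timer measures the computation.
            infer_encoding(weights, inference_rng, 8000).code_momentum.block_until_ready()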
        print(timer.elapsed)
mattjj commented 3 years ago

In my real code, I can unroll the while loop over 8000 iterations by replacing it with a while loop over 8 iterations of scans of 100 iterations

By the way, scan has a useful unroll option which will do this for you, so it should be easy to experiment:

https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html

(I should've mentioned that earlier!)
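Here is a quick sketch of experimenting with `unroll`; the step function and sizes are hypothetical. According to the `jax.lax.scan` docs, `unroll` sets how many scan iterations are placed inside a single iteration of the underlying loop. The results should stay the same while the number of loop trips drops. Making `unroll` a static argument lets you compare several values, and `block_until_ready()` keeps asynchronous dispatch out of any timing.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


def step(x, noise):
    # Hypothetical cheap update, similar in spirit to the sampler above.
    return jnp.tanh(x) + noise, None


@partial(jax.jit, static_argnames="unroll")
def run(x0, noise, unroll=1):
    x_final, _ = lax.scan(step, x0, noise, unroll=unroll)
    return x_final


noise = jax.random.normal(jax.random.PRNGKey(0), (8000, 1))
x0 = jnp.zeros((1,))

baseline = run(x0, noise, unroll=1).block_until_ready()
for unroll in (2, 10, 100):
    result = run(x0, noise, unroll=unroll).block_until_ready()
    # Unrolling changes the loop structure, not the math.
    print(unroll, bool(jnp.allclose(result, baseline)))
```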

mattjj commented 3 years ago

There's no jax.is_static, right? (I guess you can try: int(x); except TypeError:)

That's right.

Sure, which tab do you want me to screenshot?

Basically the one that looks something like this:

[Screenshot: example of the TensorBoard Trace Viewer]

I think it's called the Trace Viewer, described here in the TensorBoard docs.

It's still much slower than running this on the CPU, and I don't understand why.

I think once we look at the right part of the Trace Viewer output, we should be able to figure out what the computer is spending its time doing.

mattjj commented 3 years ago

There's good documentation on how to use the trace viewer, but unfortunately it looks like it's all Google-internal, and hasn't been open-sourced yet...

NeilGirdhar commented 3 years ago

By the way, scan has a useful unroll option which will do this for you, so it should be easy to experiment:

I know that scan has an unroll, but while_loop doesn't 😄. I realize that unrolling a while loop can change behavior since the condition is checked less often, but it doesn't matter to me since I'm using it for finding a fixed point. I considered proposing adding a corresponding unroll to while_loop, but I can do it on my side, so it's not such a big deal.
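Here is a minimal sketch of that kind of unrolling. It assumes a hypothetical contraction, `0.5 * x + 1.0`, whose fixed point is 2. The outer `lax.while_loop` checks convergence only once per chunk, and each chunk is a `lax.scan` of `CHUNK` updates. Because the condition is checked less often, the loop can run up to `CHUNK - 1` extra updates, which is harmless when iterating to a fixed point.

```python
import jax.numpy as jnp
from jax import lax

CHUNK = 100  # hypothetical number of updates per convergence check


def step(x, _):
    # Hypothetical fixed-point update; converges to x = 2.
    return 0.5 * x + 1.0, None


def fixed_point_chunked(x0, tol=1e-6):
    def not_converged(carry):
        x, x_prev = carry
        return jnp.max(jnp.abs(x - x_prev)) > tol

    def run_chunk(carry):
        x, _ = carry
        # CHUNK - 1 updates in a scan, then one more so the carry holds
        # the last two iterates for the convergence test.
        x_prev, _ = lax.scan(step, x, None, length=CHUNK - 1)
        x_next, _ = step(x_prev, None)
        return x_next, x_prev

    x1, _ = step(x0, None)
    x_final, _ = lax.while_loop(not_converged, run_chunk, (x1, x0))
    return x_final


print(fixed_point_chunked(jnp.zeros((1,))))
```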

It's also possible to create an unrolling while loop that doesn't change its behavior.
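Here is a sketch of such a behavior-preserving version, written as a hypothetical helper rather than a JAX API. Each outer iteration runs `unroll` guarded steps. A step only takes effect while `cond_fun` still holds, so the final value matches `lax.while_loop`. The guard uses `jnp.where` instead of `lax.cond`, which avoids nested control flow but computes the body even on steps whose result is discarded.

```python
import jax
import jax.numpy as jnp
from jax import lax


def unrolled_while_loop(cond_fun, body_fun, init_val, unroll=4):
    """Hypothetical helper with lax.while_loop semantics, but `unroll`
    guarded body applications per outer iteration."""

    def guarded_step(val, _):
        keep_going = cond_fun(val)
        candidate = body_fun(val)
        # Keep the old value once the condition fails, so extra steps are no-ops.
        new_val = jax.tree_util.tree_map(
            lambda new, old: jnp.where(keep_going, new, old), candidate, val)
        return new_val, None

    def outer_body(val):
        val, _ = lax.scan(guarded_step, val, None, length=unroll)
        return val

    return lax.while_loop(cond_fun, outer_body, init_val)


def cond_fun(state):
    i, _ = state
    return i < 10


def body_fun(state):
    i, x = state
    return i + 1, x * 2.0


init = (jnp.int32(0), jnp.float32(1.0))
print(lax.while_loop(cond_fun, body_fun, init))
print(unrolled_while_loop(cond_fun, body_fun, init, unroll=4))
```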

NeilGirdhar commented 3 years ago

Here are the screenshots you requested. Again, thanks for all your help! Trace Screenshot 1 Trace Screenshot 2

NeilGirdhar commented 3 years ago

adding the softplus may mean that you have to launch two kernels (one for the dot, one for the presumably fused xla-generated softplus computation)

That would definitely explain some of the slowness, but I don't see how that's better than just generating one kernel to do both things at once? Even if this were applied to an array with a million elements, it still seems like it would be faster to spawn one kernel instead of two.

Is this XLA generator code something I can look at? I guess it's not here?

mattjj commented 3 years ago

Is this XLA generator code something I can look at?

XLA:GPU is all open-source!

That would definitely explain some of the slowness, but I don't see how that's better than just generating one kernel to do both things at once? Even if this were applied to an array with a million elements, it still seems like it would be faster to spawn one kernel instead of two.

XLA:GPU's hands are tied here, because the techniques for generating the fastest GPU kernels for matmul and conv are proprietary. Only Nvidia can do it, and they release those kernels as binaries. That's what cuBLAS, cuDNN, etc. are. There are many such kernels; XLA:GPU does autotuning and kernel selection to choose the best routines for your array shapes and specific GPU hardware. See here for example. But because these are proprietary pre-built kernels, XLA can't, e.g., fuse operations into them. That's why loop bodies may have to be split into multiple separate kernels.

NeilGirdhar commented 2 years ago

So is there anything I can do finally to make my program run faster?

NeilGirdhar commented 2 years ago

I hope you don't mind that I'm looking at this again. From what I understand, https://github.com/openai/triton tries to produce single GPU kernels. Is there any hope of JAX doing something like this in the future via XLA? What a dream it would be to have fast GPU execution with the convenience of JAX's compiler.

nouiz commented 1 year ago

Hi Neil,

Have you seen this jax/triton project: https://github.com/jax-ml/jax-triton?

Also, what are your shapes? In this example, you have shape = (1,). If your layers have only 1 neuron, those are very, very tiny shapes.

NeilGirdhar commented 1 year ago

Have you seen this jax/triton project: https://github.com/jax-ml/jax-triton?

No, I had not! Thank you for sharing that. I was aware of Triton, so this is very exciting!

Also, what are your shapes? In this example, you have shape = (1,). If your layers have only 1 neuron, those are very, very tiny shapes.

Yes, for now my shapes are less than 100 as I work on getting my ideas working. I still feel like in an ideal world, it would not require bouncing between the CPU and GPU no matter what the shapes are?

nouiz commented 1 year ago

With such very small sizes, it will be hard for the GPU to be efficient automatically in the short term. I think many optimizations will be needed in addition to avoiding the bouncing between the CPU and GPU. For example, it will probably need more aggressive fusion than what XLA currently does. We are working on more XLA fusion, but do not expect any quick automatic fixes.

The quickest path would be some manual kernel via custom ops or maybe via Triton. Do you agree with that?

NeilGirdhar commented 1 year ago

The quickest path would be some manual kernel via custom ops or maybe via Triton. Do you agree with that?

You know better than I do :smile: If you say that's the quickest path, I believe you. I will look into Triton within the next couple weeks and get back to you? How do I learn about writing custom ops?

nouiz commented 1 year ago

This is the best documentation about JAX custom CUDA ops: https://github.com/dfm/extending-jax

NeilGirdhar commented 1 year ago

Thanks. I will look into that! Would I have access to Jax's automatic differentiation? Or would I need to do the differentiation myself and then implement that in CUDA?

nouiz commented 1 year ago

I only very briefly looked at custom ops myself. I'm sure that you can register another custom op for the gradient. But I think there are also other options.

@mattjj, do you know if we can provide a forward graph whose gradient JAX would then use as the custom op's gradient?

Neil, at worst, you can print your forward and backward graphs. From these, you can identify the gradient graph. Then you can create a JAX graph that computes it and ask JAX to use it for the gradient: https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
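Here is a minimal sketch of that route, with hypothetical names. `fast_f` stands in for a call into a custom kernel, and `reference_f` is a plain-JAX version of the same math. `jax.custom_vjp` lets the backward pass run `jax.vjp` on the reference implementation, so JAX still derives the gradient. This only illustrates the custom-derivative mechanism from the linked notebook; it is not how a new Primitive would be registered.

```python
import jax
import jax.numpy as jnp


def reference_f(x):
    # Pure-JAX version of the computation the custom kernel implements.
    return jnp.sin(x) * x


@jax.custom_vjp
def fast_f(x):
    # Hypothetical stand-in: a real version would call the custom kernel here.
    return reference_f(x)


def fast_f_fwd(x):
    # Save the input as the residual for the backward pass.
    return fast_f(x), x


def fast_f_bwd(x, g):
    # Let JAX differentiate the reference implementation for us.
    _, vjp_fn = jax.vjp(reference_f, x)
    return vjp_fn(g)  # tuple with one cotangent per primal input


fast_f.defvjp(fast_f_fwd, fast_f_bwd)

print(jax.grad(fast_f)(1.0))
```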

mattjj commented 1 year ago

You'd need to write differentiation rules for any new Primitives introduced. In the dfm tutorial there's code in src/kepler_jax/kepler_jax.py which shows how to define a JVP rule for the primitive introduced.

These custom differentiation rules are not meant for custom kernels (though they kinda work there...) but rather for JAX-traceable code. When introducing a new Primitive, as in the dfm tutorial (and as I would recommend for a custom kernel), you just attach transformation rules to the primitive directly.

Once you have a differentiation rule for your primitive, you can differentiate any function that applies it along with other JAX primitives.

(Someday JAX may be able to generate derivatives for Triton code automatically. It's something we're looking at, but it's a long way off.)

NeilGirdhar commented 1 year ago

First of all, Jax Triton looks amazing! Yes, it should solve my problem, with quite a bit of work on my side. So thank you for that.

However, I have some thoughts that I'd like to get feedback on.

My problem boils down to an internal scan that evaluates something like

x[i+1] = x[i] + k * f_bwd(z - f(x[i]))

Where f is a "forwards pass" function of primals, and f_bwd is the corresponding backward pass of cotangents.
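In plain JAX this update maps directly onto `jax.vjp`, which returns `f(x)` together with a function that pulls cotangents back through `f`, i.e. the `f_bwd` above. Here is a minimal sketch, with a hypothetical elementwise `f` standing in for the noisy network and a fixed number of steps run by `lax.scan`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(x):
    # Hypothetical forward pass; stands in for the noisy network.
    return jnp.tanh(x)


def iterate(x0, z, k, num_steps):
    """Run x[i+1] = x[i] + k * f_bwd(z - f(x[i])) for num_steps steps."""

    def step(x, _):
        # jax.vjp returns f(x) and the backward-pass function at x.
        fx, f_bwd = jax.vjp(f, x)
        # Pull the cotangent z - f(x) back through f.
        (dx,) = f_bwd(z - fx)
        return x + k * dx, None

    x_final, _ = lax.scan(step, x0, None, length=num_steps)
    return x_final


x0 = jnp.array([0.3, -0.7])
z = jnp.array([1.0, 0.5])
print(iterate(x0, z, 0.1, num_steps=100))
```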

If f is a simple neural network with noise, then it's fairly straightforward to write this in Triton. The backward pass can easily be written, but it's annoying. Why am I doing this? Jax is already calculating the backward pass, and I might make mistakes that I'll have to debug. That's what I meant when I asked if I would have access to Jax's automatic differentiation. It appears that I'll have to manually differentiate f and then implement that in Triton.

I also thought about how I would write this in Triton. I could just manually write every fused kernel I need. And at the end of it, I'd have a library of pieces of kernels that I could compose to do what I need. These would probably be extra methods on "modules" (from Haiku or Flax) that would do things like:

Then I would have some way of composing multiple modules into a single fused kernel. This entails two functions triton_forwards and triton_backwards that would produce a forward or backward Triton function. Each function

Then I thought: why am I doing all this? Wouldn't it make much more sense to have a conversion from XLA to Triton?

I understand that Triton is a very limited language. I understand that it may not be possible to convert everything that XLA can do to Triton. But I'm not doing anything that crazy. If the converter wants to bail out if I try to do something like take a hyperbolic sin, that's fine! I'm just doing ordinary multiplications, exponentiations, addition, etc.

And I remember Matt explaining to me that Nvidia's kernels (e.g. matrix multiplication) are better optimized than anything the user can do. But I'm pretty sure that the last time I looked at this, my runtime was dominated by kernel spawning. Even if Triton is 50% as fast as Nvidia's hand-crafted kernels, the ability to fuse literally hundreds of kernels together would more than compensate. And the reason it's hundreds is that I have a scan (described above), and each iteration of the scan is a whole new set of kernel spawning.

So, my question boils down to: Why have we decided on Jax-Triton as the solution? Why not convert XLA to Triton as best you can, and then we can keep programming in the Jax we love?

nouiz commented 1 year ago

It appears that I'll have to manually differentiate f and then implement that in Triton.

You can use JAX to compute the grad graph and print it. That way you do not need to derive it by hand: it would be semi-manual, so a little better, but clearly not ideal.
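For example, with a hypothetical `f`, `jax.make_jaxpr` prints the program JAX traces. Applying it to `jax.grad(f)` prints the gradient computation that JAX derived, which can serve as a reference when writing the backward kernel by hand.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical forward function.
    return jnp.sin(x) * x


print(jax.make_jaxpr(f)(1.0))            # forward graph
print(jax.make_jaxpr(jax.grad(f))(1.0))  # graph JAX builds for the gradient
```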

You "only" need to write a custom CUDA kernel. You do not need to do any memory allocation. You still use JAX. You write the custom kernel that XLA doesn't generate fast code for. All the rest will stay in JAX.

You can use jax-triton to write the custom kernel. To my understanding, it doesn't need to do anything else. JAX-Triton helps bridge the two tools.

I think it is less work than what you describe above.

Your idea of having an XLA/Triton backend would be hard to implement, as we currently can't guide what to do in Triton vs. XLA.

Maybe a simpler thing could be a decorator on JAX expressions that gets them converted to Triton. This is highly hypothetical, as I haven't looked at Triton enough. But I suppose this doesn't exist now.

NeilGirdhar commented 1 year ago

You do not need to do any memory allocation. You still use JAX. You write the custom kernel that XLA doesn't generate fast code for.

I understand. What I'm trying to say is that the custom kernel writing that I would be doing is tantamount to compilation, which is already done by the XLA compiler. I want to simply write in Jax using its primitives. Triton is a different language. As you point out in your comment, it is possible to convert Jax to Triton, so that's what I'm asking for when I say "I want to program in the Jax I love".

I think it is less work than what you describe above.

Just so we're clear, I'm suggesting that instead of writing Triton code, I would write a rudimentary module-to-Triton converter.

You are right that writing a single Triton kernel is less work. However, I don't think this is a good approach in my case:

Your idea of having an XLA/Triton backend would be hard to implement, as we currently can't guide what to do in Triton vs. XLA.

Can you elaborate on this?

I'd like to generalize my suggestion. What I want is for XLA to produce fused kernels rather than the kernels it's producing now. Why can't it produce fused kernels? If I can write fused kernels in Triton, surely the XLA compiler can produce such kernels?

The benefit of XLA producing such kernels is that I wouldn't have to worry about producing backwards passes, which involves:

Maybe a simpler thing could be a decorator on JAX expressions that gets them converted to Triton.

Yes, I considered something like this. This would be a fantastic step in the right direction. If you consider my more general suggestion, then what I really want is a decorator to demand that the XLA compiler produces a fused kernel for a decorated function.

PS Salut Frederic! We met ages ago when you were working on Theano. Very nice to see you here :smile:

nouiz commented 1 year ago

I agree that having XLA do the fusion for you would be great. But this isn't a trivial fusion to add. First, XLA would need to learn how to generate dot products itself. Then it would need to handle the outer loop too. I do not know anyone with the time to start that anytime soon.

I suppose you won't learn XLA and implement it yourself. So the only short term option I see is to go via Triton.

A jaxpr->Triton converter can be done more easily than an XLA->Triton converter. Hopefully this is something you could take on. But I'm not able to estimate the time it would require.

For the gradient, the simplest way that I see is to use JAX grad on the forward graph. Then reuse the jaxpr->Triton converter a second time on that gradient graph to get a second operation for the gradient.
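As a rough illustration of where a jaxpr->Triton converter might start: the walker below is hypothetical, only lists primitive names, and emits no Triton code. `jax.make_jaxpr` returns a closed jaxpr whose `.jaxpr.eqns` are the program's equations, each with a `primitive`, `invars`, `outvars` and `params`. Running the same walker on `jax.grad(f)` corresponds to using the converter "a second time" on the gradient graph. Nested jaxprs (for example, from `jit` or `scan`) live in `eqn.params` and would need recursion, which this sketch skips.

```python
import jax
from jax import lax


def f(x):
    # lax ops map directly onto jaxpr primitives.
    return lax.mul(lax.sin(x), x)


def list_primitives(fun, *example_args):
    """Hypothetical first pass of a jaxpr->Triton converter."""
    closed_jaxpr = jax.make_jaxpr(fun)(*example_args)
    names = []
    for eqn in closed_jaxpr.jaxpr.eqns:
        # A real converter would dispatch on eqn.primitive here and emit
        # code from eqn.invars, eqn.outvars and eqn.params.
        names.append(eqn.primitive.name)
    return names


print(list_primitives(f, 1.0))            # forward graph
print(list_primitives(jax.grad(f), 1.0))  # gradient graph
```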

NeilGirdhar commented 1 year ago

I agree that having XLA do the fusion for you would be great. But this isn't a trivial fusion to add. First, XLA would need to learn how to generate dot products itself. Then it would need to handle the outer loop too. I do not know anyone with the time to start that anytime soon.

Okay! That makes perfect sense. Thank you for explaining. But please do keep me posted if you learn about something like this being in the works, so that I can put my time into other things.

A jaxpr->Triton converter can be done more easily than an XLA->Triton converter. Hopefully this is something you could take on.

Interesting! I honestly thought jaxpr was XLA :laughing: Isn't it true that jitted Jax code has a Jaxpr? If so, this would solve my problem, and if you think it's reasonable, then it's definitely worth examining. Are there any tools for working with Jaxpr? Is there any documentation?

For the gradient, the simplest way that I see is to use JAX grad on the forward graph. Then reuse the jaxpr->Triton converter a second time on that gradient graph to get a second operation for the gradient.

Yes, that's what I was thinking too. That's the beauty of having such a converter.

nouiz commented 1 year ago

JAX builds jaxpr expressions, which are then converted to XLA. vmap, grad, and other JAX transformations work on the jaxpr IR. I do not know if there are any tools or documentation for it. @mattjj do you know?
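The two layers can be compared directly, using a hypothetical `f`. `jax.make_jaxpr` shows the jaxpr. `jax.jit(f).lower(...)` shows the program that is handed to XLA, which recent JAX versions print as StableHLO text.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) * x


x = jnp.float32(1.0)
jaxpr_text = str(jax.make_jaxpr(f)(x))        # JAX's own IR
lowered_text = jax.jit(f).lower(x).as_text()  # what XLA receives
print(jaxpr_text)
print(lowered_text)
```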

jakevdp commented 1 year ago

There is some information on jaxprs here: https://jax.readthedocs.io/en/latest/jaxpr.html