jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Scanning over reduction rather than reducing over scan is a powerful optimisation - but hard to implement #4968

Open SamPruden opened 4 years ago

SamPruden commented 4 years ago

Scanning over reduction can be significantly faster and more memory efficient than reducing over the stacked output of scan. That's hard to articulate so I'll give a simple example.

# SLOW AND MEMORY INEFFICIENT
# Produces one large output then does a single large reduction.
# Final reduction size is length * batch size
scan(lambda carry, _: (jnp.sin(carry), jnp.sin(carry)), x, None, length=10)[1].mean()

# FAST
# Reduces inside the scan, so no huge array is ever created.
# Final reduction size is just length
scan(lambda carry, _: (jnp.sin(carry), jnp.sin(carry).mean()), x, None, length=10)[1].mean()

I've found this to be a very powerful optimisation, but it can be awkward to work with. In my case, the stacked output of the scan is the output of my model, but in the training loop I do an MSE reduction over it to get the loss. Training speed can be significantly improved by pushing the loss reduction up into the scan, but it requires two separate implementations of the model: one for inference, and one for training.
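
To make the inference/training split concrete, here's a minimal sketch of the two parallel implementations (untested; `step` stands in for a hypothetical per-step model function, and the names and shapes are just illustrative):

import jax
import jax.numpy as jnp

def inference(params, init_carry, xs):
  # Inference: stack the full per-step outputs of the model.
  def body(carry, x):
    carry, out = step(params, carry, x)  # hypothetical per-step model function
    return carry, out
  return jax.lax.scan(body, init_carry, xs)[1]

def training_loss(params, init_carry, xs, ys):
  # Training: fold the squared-error reduction into the scan body so only a
  # length-sized vector of per-step partial sums is ever materialised.
  def body(carry, x_and_y):
    x, y = x_and_y
    carry, out = step(params, carry, x)
    return carry, jnp.square(out - y).sum()
  per_step_sums = jax.lax.scan(body, init_carry, (xs, ys))[1]
  return per_step_sums.sum() / ys.size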

Note that there are two levels of performance improvement available here. One is that moving the reduction into the loop speeds up execution by a factor of ~2. The other is that the output of the scan may well be the memory bottleneck because it's a length * batch_size allocation, and by moving the reduction inside we reduce it to simply a length allocation. In my case this allowed me to increase my batch size by a factor of half of the scan length, which was 16!

It would be great if this optimisation could be done automatically on a single implementation of the model. I don't know where the right place to do that is; perhaps this is an issue better filed with XLA. JAX already does some similar transformations - would it make sense for JAX to do this?

This is a fastmath style optimisation and may well accumulate floating point error.

I'm considering prototyping a custom JAX interpreter to do this for my narrow use case, although I'm not sure that will be a good use of my time compared to just living with the two parallel implementations.

Benchmark code

import jax
import jax.numpy as jnp
from timeit import timeit

def basic_inner(carry, _):
  a = jnp.sin(carry)
  return a, a

def integ_inner(carry, _):
  a = jnp.sin(carry)
  return a, a.mean()

basic = jax.jit(lambda x, l: jax.lax.scan(basic_inner, x, None, length=l)[1].mean(), static_argnums=1)
integ = jax.jit(lambda x, l: jax.lax.scan(integ_inner, x, None, length=l)[1].mean(), static_argnums=1)
basic_g = jax.jit(jax.grad(lambda x, l: jax.lax.scan(basic_inner, x, None, length=l)[1].mean()), static_argnums=1)
integ_g = jax.jit(jax.grad(lambda x, l: jax.lax.scan(integ_inner, x, None, length=l)[1].mean()), static_argnums=1)

def benchmark(batch_size, length):
  x = jnp.ones(batch_size)

  for _ in range(10): basic(x, length).block_until_ready()
  basic_time = timeit(lambda: basic(x, length).block_until_ready(), number = 100)

  for _ in range(10): integ(x, length).block_until_ready()
  integ_time = timeit(lambda: integ(x, length).block_until_ready(), number = 100)

  for _ in range(10): basic_g(x, length).block_until_ready()
  basic_g_time = timeit(lambda: basic_g(x, length).block_until_ready(), number = 100)

  for _ in range(10): integ_g(x, length).block_until_ready()
  integ_g_time = timeit(lambda: integ_g(x, length).block_until_ready(), number = 100)

  print(f" -- BATCH SIZE: {batch_size}, LENGTH: {length} -- ")
  print(f"basic     : {basic_time:.2f}, integ     : {integ_time:.2f} - factor: {(basic_time / integ_time):.2f}")
  print(f"basic grad: {basic_g_time:.2f}, integ grad: {integ_g_time:.2f} - factor: {(basic_g_time / integ_g_time):.2f}")
  print()

benchmark(1 * 2048, 10)
benchmark(2 * 2048, 10)
benchmark(4 * 2048, 10)
benchmark(8 * 2048, 10)
benchmark(64 * 2048, 10)
benchmark(128 * 2048, 10)
benchmark(256 * 2048, 10)
benchmark(512 * 2048, 10)
benchmark(1024 * 2048, 10)

benchmark(1 * 2048, 100)
benchmark(2 * 2048, 100)
benchmark(4 * 2048, 100)
benchmark(8 * 2048, 100)
benchmark(64 * 2048, 100)
benchmark(128 * 2048, 100)
benchmark(256 * 2048, 100)
benchmark(512 * 2048, 100)
benchmark(1024 * 2048, 100)

# Memory test
integ(jnp.ones(10 * 32 * 64 * 2048), 100).block_until_ready()
print("Still alive after integ!")

# Commented out because it OOMs and crashes
# basic(jnp.ones(10 * 32 * 64 * 2048), 100).block_until_ready()
# print("Still alive after basic!")

CPU results

 -- BATCH SIZE: 2048, LENGTH: 10 -- 
basic     : 0.06, integ     : 0.04 - factor: 1.34
basic grad: 0.06, integ grad: 0.06 - factor: 1.04

 -- BATCH SIZE: 4096, LENGTH: 10 -- 
basic     : 0.07, integ     : 0.05 - factor: 1.42
basic grad: 0.10, integ grad: 0.07 - factor: 1.31

 -- BATCH SIZE: 8192, LENGTH: 10 -- 
basic     : 0.39, integ     : 0.07 - factor: 5.53
basic grad: 0.42, integ grad: 0.12 - factor: 3.50

 -- BATCH SIZE: 16384, LENGTH: 10 -- 
basic     : 0.48, integ     : 0.11 - factor: 4.32
basic grad: 0.57, integ grad: 0.50 - factor: 1.16

 -- BATCH SIZE: 2048, LENGTH: 100 -- 
basic     : 0.54, integ     : 0.12 - factor: 4.42
basic grad: 0.59, integ grad: 0.49 - factor: 1.20

 -- BATCH SIZE: 4096, LENGTH: 100 -- 
basic     : 0.73, integ     : 0.22 - factor: 3.39
basic grad: 0.87, integ grad: 0.69 - factor: 1.26

 -- BATCH SIZE: 8192, LENGTH: 100 -- 
basic     : 1.11, integ     : 0.39 - factor: 2.86
basic grad: 1.42, integ grad: 1.04 - factor: 1.36

 -- BATCH SIZE: 16384, LENGTH: 100 -- 
basic     : 1.89, integ     : 0.74 - factor: 2.56
basic grad: 2.56, integ grad: 1.76 - factor: 1.46

Still alive after integ!

GPU results

 -- BATCH SIZE: 2048, LENGTH: 10 -- 
basic     : 0.06, integ     : 0.06 - factor: 1.00
basic grad: 0.11, integ grad: 0.10 - factor: 1.07

 -- BATCH SIZE: 4096, LENGTH: 10 -- 
basic     : 0.07, integ     : 0.07 - factor: 0.91
basic grad: 0.10, integ grad: 0.09 - factor: 1.07

 -- BATCH SIZE: 8192, LENGTH: 10 -- 
basic     : 0.06, integ     : 0.07 - factor: 0.91
basic grad: 0.09, integ grad: 0.09 - factor: 0.98

 -- BATCH SIZE: 16384, LENGTH: 10 -- 
basic     : 0.07, integ     : 0.07 - factor: 0.89
basic grad: 0.09, integ grad: 0.10 - factor: 0.90

 -- BATCH SIZE: 131072, LENGTH: 10 -- 
basic     : 0.07, integ     : 0.08 - factor: 0.97
basic grad: 0.11, integ grad: 0.09 - factor: 1.14

 -- BATCH SIZE: 262144, LENGTH: 10 -- 
basic     : 0.12, integ     : 0.08 - factor: 1.50
basic grad: 0.17, integ grad: 0.12 - factor: 1.39

 -- BATCH SIZE: 524288, LENGTH: 10 -- 
basic     : 0.19, integ     : 0.13 - factor: 1.49
basic grad: 0.30, integ grad: 0.22 - factor: 1.37

 -- BATCH SIZE: 1048576, LENGTH: 10 -- 
basic     : 0.32, integ     : 0.20 - factor: 1.58
basic grad: 0.56, integ grad: 0.38 - factor: 1.48

 -- BATCH SIZE: 2097152, LENGTH: 10 -- 
basic     : 0.56, integ     : 0.34 - factor: 1.67
basic grad: 1.08, integ grad: 0.69 - factor: 1.58

 -- BATCH SIZE: 2048, LENGTH: 100 -- 
basic     : 0.31, integ     : 0.31 - factor: 1.00
basic grad: 0.50, integ grad: 0.53 - factor: 0.94

 -- BATCH SIZE: 4096, LENGTH: 100 -- 
basic     : 0.31, integ     : 0.32 - factor: 0.97
basic grad: 0.53, integ grad: 0.52 - factor: 1.03

 -- BATCH SIZE: 8192, LENGTH: 100 -- 
basic     : 0.30, integ     : 0.32 - factor: 0.93
basic grad: 0.52, integ grad: 0.51 - factor: 1.02

 -- BATCH SIZE: 16384, LENGTH: 100 -- 
basic     : 0.31, integ     : 0.31 - factor: 1.01
basic grad: 0.50, integ grad: 0.52 - factor: 0.96

 -- BATCH SIZE: 131072, LENGTH: 100 -- 
basic     : 0.34, integ     : 0.32 - factor: 1.09
basic grad: 0.66, integ grad: 0.53 - factor: 1.25

 -- BATCH SIZE: 262144, LENGTH: 100 -- 
basic     : 0.72, integ     : 0.38 - factor: 1.87
basic grad: 1.31, integ grad: 0.83 - factor: 1.57

 -- BATCH SIZE: 524288, LENGTH: 100 -- 
basic     : 1.47, integ     : 0.89 - factor: 1.66
basic grad: 2.81, integ grad: 1.80 - factor: 1.56

 -- BATCH SIZE: 1048576, LENGTH: 100 -- 
basic     : 2.67, integ     : 1.57 - factor: 1.71
basic grad: 5.39, integ grad: 3.32 - factor: 1.62

 -- BATCH SIZE: 2097152, LENGTH: 100 -- 
basic     : 5.07, integ     : 2.80 - factor: 1.81
basic grad: 10.60, integ grad: 6.28 - factor: 1.69

Still alive after integ!
SamPruden commented 3 years ago

I've taken a quick run at implementing a custom JAX transformation for this. It seems promising so far! This was built with no real planning or testing, and documentation for this sort of thing doesn't really exist, so I just hacked something together that works. It's probably a very poor implementation and doubtless has bugs and unaccounted-for edge cases galore. It seems to work on the basic things I've tested it on, but I really only did it to explore the transformation capabilities and to see how hard it is.

This only moves a simple sum reduction up. That means it isn't actually useful for me yet as I need it to be able to move a whole MSE reduction up. That will require significantly more manipulation and I may or may not bother taking a run at that.

If anybody with familiarity with the internal API would like to comment on whether I'm approaching this in the right way, that would be appreciated! I'm wondering if there are some utilities that make this type of transformation easier? Perhaps I should be creating new jaxprs from scratch instead of patching existing ones?

Major things still on my TODO list:

I may get these things done, or I may never look at this code again.

I don't know how well this would work as an inbuilt automatic transformation or whether JAX would be the correct level to do this on. If it's a good optimisation then I would think it should happen in XLA. There could still be a good reason why this transformation isn't done, and I just haven't found it yet.

def push_up_reductions(fun):
  from functools import wraps
  from jax import core
  from jax import lax
  # from jax.util import safe_map

  def scan_predicate(eqn):
    # TODO: Skip if reduction already in place
    # Necessary for doing multiple iterations
    if eqn.primitive != lax.scan_p: return False
    if eqn.outvars[1] == core.dropvar: return False
    # Excluded because I'm too lazy to think about them for now
    if eqn.params['reverse']: return False
    if eqn.params['unroll'] != 1: return False
    return True

  def get_dependant_reductions(eqns, var):
    # TODO: Handle transparent operations like scalar multiply which reductions can be moved through
    dependants = [eqn for eqn in eqns if var in eqn.invars]
    return dependants if all(dep.primitive == lax.reduce_sum_p for dep in dependants) else None

  def patch_scan_body(closed_jaxpr, axes):
    jaxpr = closed_jaxpr.jaxpr
    old_outvar = jaxpr.outvars[1]
    out_aval = closed_jaxpr.out_avals[1]
    newvar = core.gensym([jaxpr])
    sumshape = [out_aval.shape[i] for i in range(out_aval.ndim) if i not in axes]
    sumvar = newvar(jax.ShapedArray(sumshape, out_aval.dtype))

    # Add the reduction over the movable axes
    jaxpr.eqns.append(core.JaxprEqn([old_outvar], [sumvar], lax.reduce_sum_p, {'axes': axes}, None))

    # Broadcast back to the original ndim (reduced axes become size 1) so that outer reductions can be left as is
    dims = tuple(i for i in range(out_aval.ndim) if i not in axes)
    shape = tuple(1 if i in axes else out_aval.shape[i] for i in range(out_aval.ndim))
    outvar = newvar(jax.ShapedArray(shape, out_aval.dtype))
    jaxpr.eqns.append(core.JaxprEqn(
      [sumvar], [outvar], lax.broadcast_in_dim_p,
      {'broadcast_dimensions': dims, 'shape': shape}, None)
    )

    # Rewire the body's stacked output to the reduced-and-broadcast value
    jaxpr.outvars[1] = outvar

  def patch_eqns(jaxpr):
    scans = filter(scan_predicate, jaxpr.eqns)
    for scan in scans:
      dependants = get_dependant_reductions(jaxpr.eqns, scan.outvars[1])
      if dependants is None: continue

      # If there are multiple dependant reductions, we can move up the intersection of their axes
      movable_shape = set.intersection(*(set(dep.params['axes']) for dep in dependants))
      movable_shape.discard(0)
      movable_shape = tuple(movable_shape)
      if movable_shape == (): continue

      # Patch the existing scan body
      inner_axes = tuple(a - 1 for a in movable_shape)
      patch_scan_body(scan.params['jaxpr'], inner_axes)

  def wrapped(*args, **kwargs):
    closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
    patch_eqns(closed_jaxpr.jaxpr)
    return core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)[0]

  return wraps(fun)(wrapped)
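
For reference, here's a quick (untested) usage sketch against a toy loss; printing the jaxpr should show the reduce_sum over the trailing axis sitting inside the scan body:

import jax
import jax.numpy as jnp

def toy_loss(x):
  # Stacked scan output followed by a plain sum reduction.
  return jax.lax.scan(lambda c, _: (jnp.sin(c), jnp.sin(c)), x, None, length=10)[1].sum()

optimised = push_up_reductions(toy_loss)
print(optimised(jnp.ones(4)))                  # should match toy_loss(jnp.ones(4))
print(jax.make_jaxpr(optimised)(jnp.ones(4)))  # the reduce_sum should now appear inside the scan body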
jekbradbury commented 3 years ago

One reason this hasn't been implemented in XLA might be that there's no native "scan" HLO, so this transformation would be a fairly involved kind of loop-invariant code motion rather than the pattern match that it is for jaxprs. @blakehechtman to maybe comment more?

SamPruden commented 3 years ago

One reason this hasn't been implemented in XLA might be that there's no native "scan" HLO, so this transformation would be a fairly involved kind of loop-invariant code motion rather than the pattern match that it is for jaxprs.

Ah that's interesting. Would it make sense to introduce a transformation like this into JAX then? I don't know whether it's considered appropriate for JAX to do these types of optimisations itself, or whether the philosophy is that jit just passes everything transparently to XLA. The floating point differences that it introduces may be an issue too; would it need to be toggleable or even off by default because of that?

Of course, if XLA could detect it that would benefit other frameworks too, so that seems the most appropriate place if it's practical to do.

mattjj commented 3 years ago

This is great stuff! Thanks for providing so much detail, and for digging in so readily.

Would it make sense to introduce a transformation like this into JAX then?

So far the philosophy is indeed that JAX doesn't do rewrites spanning multiple primitives. That doesn't always have to be the case, but it helps us keep things simple so it seems worth sticking with as long as we can.

The floating point issues might also make this optimization not always desirable, though I'm not sure. It can be hard to find compiler optimizations that are uniformly profitable. We've seen other similar examples.

But there may be another possibility here: going meta a bit, maybe this is an example of a domain- or problem-specific compiler optimization, and so the right answer is to make optimizations like these as easy as possible for users to add themselves. That way expert users can get exactly the optimizations they want, and we don't need them to be profitable across all reasonable applications (like they might have to be to include in XLA or even JAX core). That could include, for example, some basic pattern-matching-rewrite machinery.

There may be a tradeoff there: we need to be defensive about internal APIs so that we can maintain the ability to revise JAX internals. If we lose the ability to revise things because so many people depend on internal details, JAX will stagnate. But we may be able to find a balance: for example, if these APIs were only used by researchers, there might be a lower expectation of API stability and long-term compatibility. (We've already found that intrepid researchers and other experts end up building cool things with JAX internal APIs, and so mainly this is about expectations management, so that people stay happy even if we have to revise these JAX APIs and revise their code.)

Anyway, just wanted to share those high-level thoughts. Having a user-malleable compiler was one of the original motivating visions behind JAX, and this same thought has come up again recently as well.

blakehechtman commented 3 years ago

does this make even more sense? scan(lambda (carry, sum), _: ((jnp.sin(carry), sum+jnp.sin(carry).mean()) / 10)),(x, 0), None, length=10)[0][1]

blakehechtman commented 3 years ago

excuse my potentially terrible python

SamPruden commented 3 years ago

Thanks for the detailed reply, Matt! I've been sleeping on it; here's where my thoughts are today.

So far the philosophy is indeed that JAX doesn't do rewrites spanning multiple primitives. That doesn't always have to be the case, but it helps us keep things simple so it seems worth sticking with as long as we can.

I can see why! Doing this comprehensively has quite a lot of tricky rules, and I'm not at all looking forward to the prospect of having to write tests for it.

The floating point issues might also make this optimization not always desirable, though I'm not sure. It can be hard to find compiler optimizations that are uniformly profitable. We've seen other similar examples.

Yeah, I'm not sure what the policy is here. As I understand it, you do have some parts of fastmath on by default, so I presume it's not an instant deal breaker?

But there may be another possibility here: going meta a bit, maybe this is an example of a domain- or problem-specific compiler optimization, and so the right answer is to make optimizations like these as easy as possible for users to add themselves. That way expert users can get exactly the optimizations they want, and we don't need them to be profitable across all reasonable applications (like they might have to be to include in XLA or even JAX core). That could include, for example, some basic pattern-matching-rewrite machinery.

That type of rewrite machinery would be very welcome. I admit that I'd hoped I might find something like that when I looked into doing this.

I'm not sure leaving this to each user makes sense in this case. The problem is that this isn't a good use of my time. If I just wanted to get this problem solved, I would refactor and live with the two parallel implementations of the model. It would probably take about 20 minutes. I'd have to live with slightly uglier code and the risk of bugs where the two versions get out of sync, but that would be more practical than modifying the compiler. The peace of mind of knowing that it works, rather than worrying that you may have a bug in your compiler mod, is worth a lot on its own. If we left this elegant solution up to the user, I don't think they'd do it.

It's also something that a significant number of users may benefit from, but only a few would think to do. It may be doing them a disservice not to give this to them out of the box.

An obvious user of this would be Trax. If I get it working, I may run some benchmarks with their RNN layers and models. If it improves things significantly on their components, they would probably want to implement it. One option would be for them to do it in their project, another would be for JAX to provide it to them.

If this were an official feature you wouldn't have to worry about the internal API problem because you would just maintain this. Would it make sense to provide this as a transformation that can be dropped in as easily as grad, but that never gets automatically applied? That might be a good balance between ease of use and not changing the default behaviour.

SamPruden commented 3 years ago

does this make even more sense? scan(lambda (carry, sum), _: ((jnp.sin(carry), sum+jnp.sin(carry).mean()) / 10)),(x, 0), None, length=10)[0][1]

I woke up thinking a similar thing this morning. I can performance test this as a special case, but my instinct is that it's not going to be worth doing. It's not actually as powerful as the reduce_sum version because it only deals with summation over the first axis, whereas a reduction can go over many. I would hope that XLA would make these about the same anyway.
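
For reference, a cleaned-up (untested) sketch of how I read that suggestion, accumulating the running mean in the carry:

import jax
import jax.numpy as jnp

def carry_mean_inner(carry_and_acc, _):
  carry, acc = carry_and_acc
  a = jnp.sin(carry)
  # Accumulate the running mean in the carry rather than stacking outputs.
  return (a, acc + a.mean() / 10), None

carry_mean = jax.jit(lambda x: jax.lax.scan(
    carry_mean_inner, (x, jnp.zeros((), x.dtype)), None, length=10)[0][1])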

I'll give a slightly more comprehensive overview of the type of transformation I'm aiming at.

Scan + MSE toy demo:

x1 = jnp.ones((10, 5, 8))
x2 = jnp.ones((30, 15, 12))
y = jnp.ones((30, 15, 12))

jax.make_jaxpr(
  lambda x1, x2, y: jnp.square(y - lax.scan(lambda c, xs: (jnp.sin(c), jnp.cos(xs)), x1, x2)[1]).mean()
)(x1, x2, y)

Note that I've used 3D inputs, so it can't just be an addition in the carry.

What JAX does at the moment:

{ lambda  ; a b c.
  let _ d = scan[ jaxpr={ lambda  ; a b.
                          let c = sin a
                              d = cos b
                          in (c, d) }
                  length=30
                  linear=(False, False)
                  num_carry=1
                  num_consts=0
                  reverse=False
                  unroll=1 ] a b
      e = sub c d
      f = integer_pow[ y=2 ] e
      g = reduce_sum[ axes=(0, 1, 2) ] f
      h = div g 5400.0
  in (h,) }

What I would hope for it to get turned into:

{ lambda  ; a b c.
  let _ d = scan[ jaxpr={ lambda  ; a b c.
                          let d = sin a
                              e = cos b
                              f = sub c e
                              g = integer_pow[ y=2 ] f
                              h = reduce_sum[ axes=(0, 1) ] g
                          in (d, h) }
                  length=30
                  linear=(False, False, False)
                  num_carry=1
                  num_consts=0
                  reverse=False
                  unroll=1 ] a b c
      e = reduce_sum[ axes=(0,) ] d
      f = div e 5400.0
  in (f,) }

I haven't tested that, but I think that it's correct...

Chillee commented 3 years ago

If I'm understanding correctly, is this the same optimization that KeOps performs? http://www.kernel-operations.io/keops/_auto_benchmarks/index.html

They have a lot of examples of potential users of this kind of optimization.