google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Efficient diag(JtJ) #20600

Open ywrt opened 3 months ago

ywrt commented 3 months ago

As best as I can tell, the most efficient way to get diag(JtJ) is to do jnp.square(primals).T @ jnp.square(cotangents), but this seems to be difficult to cleanly implement in JAX.
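
For concreteness, here's a minimal numerical check of that identity for a single linear layer with a per-example loss (a toy sketch; the shapes and helper names are just for illustration):

import jax
import jax.numpy as jnp

def per_example_loss(W, x):  # x: (features,), W: (features, out)
  return jnp.sum(jax.nn.relu(x @ W))

x = jax.random.normal(jax.random.key(0), (8, 5))  # batch of 8 examples, 5 features
W = jax.random.normal(jax.random.key(1), (5, 3))

# Reference: square each per-example gradient of W, then sum over the batch.
per_ex_grads = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0))(W, x)
diag_jtj_ref = jnp.square(per_ex_grads).sum(axis=0)

# The cheap route: square the layer inputs and the output cotangents, one matmul.
cotangents = jax.vmap(jax.grad(lambda y: jnp.sum(jax.nn.relu(y))))(x @ W)
diag_jtj = jnp.square(x).T @ jnp.square(cotangents)

print(jnp.allclose(diag_jtj, diag_jtj_ref))  # True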

I have an utterly terrible way of doing this by using a custom VJP:

import jax
import jax.numpy as jnp
import flax.linen as nn

@jax.custom_vjp
def matmul(a, b): return a @ b
def matmul_fwd(a, b): return a @ b, (a, b)
def matmul_bwd(res, y_bar):
  a, b = res
  da = y_bar @ b.T
  # db = a.T @ y_bar  # What would actually compute the gradient.
  db = jnp.square(a.T) @ jnp.square(y_bar)  # Instead compute the diag(JtJ) terms and return them as the "gradient".
  return da, db
matmul.defvjp(matmul_fwd, matmul_bwd)

class Test(nn.Module):
  @nn.compact
  def __call__(self, x):
    w1 = self.param('w1', nn.initializers.normal(), (x.shape[-1], 3))
    x = matmul(x, w1) # x @ w1
    x = jax.nn.relu(x)
    w2 = self.param('w2', nn.initializers.normal(), (x.shape[-1], 1))    
    x = matmul(x, w2) # x @ w2    
    return x.sum()

test = Test()
x = jnp.ones((15, 32))
params = test.init(jax.random.key(1), x)
grad = jax.grad(test.apply)(params, x)  # Actually computes diag(JtJ) rather than the gradient (!)

Needless to say, this is rather ugly and fragile. So the question: Is there a good way to implement this with comparable efficiency?

I did look at doing per-example gradients and then squaring and summing them, but that seems to be much less efficient: It builds the full outer product in memory before reducing. There doesn't seem to be an optimizer pass that will promote the square operation to the other side of the matrix multiply :(

mattjj commented 3 months ago

Thanks for the question!

I'm not sure, but I wouldn't be surprised if we can't do the maximally efficient thing as a composition of just jax.jvp and jax.vjp (and jax.vmap), and instead we need to write a special structure-exploiting recursion. By that I mean something like in #18902 (and the libraries discussed in the comments, which do a much better job with that, though I think the fundamental math is about the same), or like what people have done for computing log det Jacobians efficiently (like Oryx does), or what I believe this 'fast NTK' stuff does. Actually, the latter may be computing something similar to what you want, though they're usually after full JJ' matrices rather than diag(J'J) matrices.

Indeed I think your rule for matmul is representing exactly such a structure-exploiting rule: because the Jacobian of v -> Av is simply A, and because diag(A'A) can be cleverly computed using a Hadamard product A.^2 (and summing), we can avoid ever having to do a matmul to compute the coefficients of diag(A'A).
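
Concretely, that identity is just diag(A'A)_j = sum_i A_ij^2, e.g. this small check:

import jax
import jax.numpy as jnp

A = jax.random.normal(jax.random.key(0), (4, 3))
print(jnp.allclose(jnp.diag(A.T @ A), jnp.square(A).sum(axis=0)))  # True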

I want to think about this more when I have more time, but wanted to plant that initial guess.

I'm starting to think that there are several such special autodiff recursions of interest, and we should factor the JAX autodiff machinery to make those things easy to write (i.e. you just write the rules, then the library takes care of all the JAX-internals boilerplate). Then we could build more such things into JAX, and also make it easy to stand up user libraries.

WDYT?

mattjj commented 3 months ago

As you point out, we could play that trick with the whole Jacobian, or rather all the per-example gradients, and maybe we can't get better asymptotic FLOP costs than that. But with a special recursion we'll be doing the computation in the right order, and thus not OOMing for example, rather than relying on the compiler to clean up everything.

ywrt commented 3 months ago

The special recursion absolutely sounds like something that would be very useful. I'll have to think about exactly what that might look like in the more general case.

On the optimization side, it did occur to me that it's a fairly natural one. Computing diag(JtJ) as:

  def compute_jtj(params, x):
    per_ex_grads = jax.vmap(jax.grad(test.apply), in_axes=(None, 0))(params, x)
    jtj = jax.tree_util.tree_map(lambda g: jnp.square(g).sum(axis=0), per_ex_grads)
    # grads = jax.tree_util.tree_map(lambda g: g.sum(axis=0), per_ex_grads)
    return jtj

This should compile to something like:

  x = jnp.einsum('b i, b j -> b i j', primals, cotangents)
  x = jnp.square(x)
  x = x.sum(axis=0)

The matrix multiplication there has no contracting dimension, and the inputs are much smaller than the output, so lifting the square op to before the matrix multiply should be a very easy optimization opportunity (lower FLOPs, lower memory BW, and no precision issues since there are no additions). After that, it becomes:

  a = jnp.square(primals)
  b = jnp.square(cotangents)
  x = jnp.einsum('b i, b j -> b i j', a, b)
  x = x.sum(axis=0)

At which point, folding the sum() into the dot_general is another zero-risk optimization: it's clearly just adding a contracting dimension. That saves a lot of memory bandwidth and has no impact on precision.
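
Written out by hand, the combined end state of those two steps would be something like (just a sketch, reusing the names above):

  a = jnp.square(primals)
  b = jnp.square(cotangents)
  x = jnp.einsum('b i, b j -> i j', a, b)  # the batch dim becomes the contracting dim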

So maybe the easier thing in the near term is to enable these optimizations? I'm not sure if they'd go in JAX or XLA (probably XLA?). I absolutely agree it would be better not to rely on compiler magic, but these may be such clear wins that not doing them is a bug! :)

mattjj commented 3 months ago

Good point about moving the square through the outer product. I see now what you meant in your first message.

Is that the only optimization you need here?

ywrt commented 3 months ago

Is the conversion of the sum() into an additional contracting dimension for the preceding dot_general an already existing optimization?

If so then yes, lifting the square is the only one remaining.

ywrt commented 3 months ago

Just checked, looks like merging the sum into the dotgeneral does happen:

def f(x, w):
  x = jnp.einsum('b i, b j -> b i j', x, w)
  x = x.sum(axis=0)
  return x
print(jax.jit(f).lower(jnp.ones((64,64)), jnp.ones((64, 64))).compile().as_text())

gives


ENTRY %main.10 (Arg_0.1: f32[64,64], Arg_1.2: f32[64,64]) -> f32[64,64] {
  %Arg_1.2 = f32[64,64]{1,0} parameter(1), sharding={replicated}
  %Arg_0.1 = f32[64,64]{1,0} parameter(0), sharding={replicated}
  ROOT %custom-call = f32[64,64]{1,0} custom-call(f32[64,64]{1,0} %Arg_0.1, f32[64,64]{1,0} %Arg_1.2), custom_call_target="__cublas$gemm", frontend_attributes={fingerprint_before_lhs="f45c017d4cbc60a6d36d764a48114f68"}, metadata={op_name="jit(f)/jit(main)/b i, b j -> b i j/dot_general[dimension_numbers=(((), ()), ((0,), (0,))) precision=None preferred_element_type=float32]"}, backend_config={"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["0"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}
}

which looks like it has folded the sum into a contracting dimension. So yes, just lifting the square is enough to make the per-example gradient avenue efficient.

ywrt commented 3 months ago

Hmm. Before I say that, the compiled code for the current per-example computation is a little weird:

  ...
  %select.1 = f32[128,256]{1,0} select(pred[128,256]{1,0} %compare.1, f32[128,256]{1,0} %broadcast.22, f32[128,256]{1,0} %broadcast.23), metadata={op_name="jit(jtj)/jit(main)/vmap(transpose(jvp(Test)))/select_n"}
  %broadcast.21 = f32[512,128,256]{2,1,0} broadcast(f32[128,256]{1,0} %select.1), dimensions={1,2}, metadata={op_name="jit(jtj)/jit(main)/vmap(transpose(vmap(jvp(Test))))/dot_general[dimension_numbers=(((), ()), ((0,), (0,))) precision=None preferred_element_type=float32]"}
  %param_0.5 = f32[512,256]{1,0} parameter(0)
  %broadcast.20 = f32[512,128,256]{2,1,0} broadcast(f32[512,256]{1,0} %param_0.5), dimensions={0,2}, metadata={op_name="jit(jtj)/jit(main)/vmap(transpose(vmap(jvp(Test))))/dot_general[dimension_numbers=(((), ()), ((0,), (0,))) precision=None preferred_element_type=float32]"}
  %multiply.14 = f32[512,128,256]{2,1,0} multiply(f32[512,128,256]{2,1,0} %broadcast.21, f32[512,128,256]{2,1,0} %broadcast.20), metadata={op_name="jit(jtj)/jit(main)/vmap(transpose(vmap(jvp(Test))))/dot_general[dimension_numbers=(((), ()), ((0,), (0,))) precision=None preferred_element_type=float32]"}
  %multiply.13 = f32[512,128,256]{2,1,0} multiply(f32[512,128,256]{2,1,0} %multiply.14, f32[512,128,256]{2,1,0} %multiply.14), metadata={op_name="jit(jtj)/jit(main)/mul"}
  ROOT %reduce.7 = f32[512,128]{1,0} reduce(f32[512,128,256]{2,1,0} %multiply.13, f32[] %constant_9), dimensions={2}, to_apply=%region_0.24

It's actually broadcasting out the primals and tangents before doing an element-wise multiplication, not doing a dotgeneral at all. That seems ... odd? Unless there's some downstream optimization, that's going to use more memory and more memory BW than doing the dotgeneral?

mattjj commented 3 months ago

It would be good to check the jaxpr, but I think you can see from the name metadata that those ops started from dot_generals in the jaxpr.

I'd want to know that the per-example gradient computation looks right in the jaxpr; then it's just a question of how XLA is deciding to lower it (which would be platform-dependent, size-dependent, etc).
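
For example (assuming the compute_jtj, params, and x from earlier in this thread), something like

print(jax.make_jaxpr(compute_jtj)(params, x))

would show the dot_generals before any lowering or backend-specific decisions.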

ywrt commented 3 months ago

Yes, the jaxpr looks sane.

module @jit_jtj attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<512x128xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<128x1xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg2: tensor<256x512xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<512x128xf32> {jax.result_info = "['params']['w1']", mhlo.layout_mode = "default"}, tensor<128x1xf32> {jax.result_info = "['params']['w2']", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.dot_general %arg2, %arg0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<256x512xf32>, tensor<512x128xf32>) -> tensor<256x128xf32>
    %1 = call @relu(%0) : (tensor<256x128xf32>) -> tensor<256x128xf32>
    %2 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<f32>) -> tensor<256x128xf32>
    %4 = stablehlo.compare  GT, %0, %3,  FLOAT : (tensor<256x128xf32>, tensor<256x128xf32>) -> tensor<256x128xi1>
    %5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %7 = stablehlo.dot_general %6, %1, contracting_dims = [] x [], precision = [DEFAULT, DEFAULT] : (tensor<1xf32>, tensor<256x128xf32>) -> tensor<1x256x128xf32>
    %8 = stablehlo.transpose %7, dims = [1, 2, 0] : (tensor<1x256x128xf32>) -> tensor<256x128x1xf32>
    %9 = stablehlo.dot_general %6, %arg1, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<1xf32>, tensor<128x1xf32>) -> tensor<128xf32>
    %10 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %11 = stablehlo.broadcast_in_dim %10, dims = [] : (tensor<f32>) -> tensor<128xf32>
    %12 = stablehlo.constant dense<true> : tensor<i1>
    %13 = stablehlo.broadcast_in_dim %12, dims = [] : (tensor<i1>) -> tensor<256x128xi1>
    %14 = stablehlo.compare  EQ, %4, %13,  UNSIGNED : (tensor<256x128xi1>, tensor<256x128xi1>) -> tensor<256x128xi1>
    %15 = stablehlo.broadcast_in_dim %11, dims = [1] : (tensor<128xf32>) -> tensor<256x128xf32>
    %16 = stablehlo.broadcast_in_dim %9, dims = [1] : (tensor<128xf32>) -> tensor<256x128xf32>
    %17 = stablehlo.select %14, %16, %15 : tensor<256x128xi1>, tensor<256x128xf32>
    %18 = stablehlo.dot_general %17, %arg2, batching_dims = [0] x [0], contracting_dims = [] x [], precision = [DEFAULT, DEFAULT] : (tensor<256x128xf32>, tensor<256x512xf32>) -> tensor<256x128x512xf32>
    %19 = stablehlo.transpose %18, dims = [0, 2, 1] : (tensor<256x128x512xf32>) -> tensor<256x512x128xf32>
    %20 = stablehlo.multiply %19, %19 : tensor<256x512x128xf32>
    %21 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %22 = stablehlo.reduce(%20 init: %21) applies stablehlo.add across dimensions = [0] : (tensor<256x512x128xf32>, tensor<f32>) -> tensor<512x128xf32>
    %23 = stablehlo.multiply %8, %8 : tensor<256x128x1xf32>
    %24 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %25 = stablehlo.reduce(%23 init: %24) applies stablehlo.add across dimensions = [0] : (tensor<256x128x1xf32>, tensor<f32>) -> tensor<128x1xf32>
    return %22, %25 : tensor<512x128xf32>, tensor<128x1xf32>
  }
  func.func private @relu(%arg0: tensor<256x128xf32>) -> tensor<256x128xf32> {
    %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<f32>) -> tensor<256x128xf32>
    %2 = stablehlo.maximum %arg0, %1 : tensor<256x128xf32>
    return %2 : tensor<256x128xf32>
  }
}

The square() at %20 needs to be lifted through the transpose at %19 as well as the dot_general at %18.

mattjj commented 3 months ago

I asked Blake the compiler guru and he said:

Seems like it is equivalent on reals but definitely not numerically identical, so XLA probably would not do it. It isn't unstable, just different numerically. We don't know what the user is intending in general, so we avoid elementwise reassociation.

I'm following up a bit more, but it leaves us in an interesting spot: how would we solve this on the JAX side? Do we try to provide some mechanism to perform this optimization? Or instead of casting this as an optimization problem, do we instead work out some way to express the right computation well in the first place (along the lines of some 'special autodiff' recursion)?

ywrt commented 3 months ago

I must admit it seems a little weird not to optimize this on the compiler side. This optimization reduces FLOP count and memory bandwidth by a factor equal to the batch size (e.g. 1/64th of the mem-BW for a batch size of 64). Maybe I'm misunderstanding the XLA goals?

I take the general point about the elementwise reassociation, but in this case reassociating the square unlocks a huge decrease in the dotgeneral cost.

That aside, it seems easy enough to optimize at the jaxpr level. As far as I know, there are no hooks to add optimizations there? Ideally, this would be a callback from somewhere that takes and returns a Jaxpr. Maybe mlir.jaxpr_subcomp()?

I just spent way too long reading through chunks of the stages framework to see if I could mutate the Lowered() object before calling compile() on it, but it seems ... extremely fraught. It seems I can use jax.jit(...).lower(...)._hlo to get an ir.Module, and then get the assembly language from that as text, and then re-parse it back into a new ir.Module and assign it to _hlo, and then call compile()! This seems to work, but ... ???

mattjj commented 2 months ago

I realized I'm actually confused about what you want to compute... When you write diag(JtJ), by J do you mean the Jacobian of a scalar-valued function, like apply in your example? In that case, you'd be taking the diagonal elements of an outer product, i.e. you'd just be getting the elementwise square of the gradient...

Can you define J?

mattjj commented 2 months ago

"Maybe I'm misunderstanding the XLA goals?"

Yeah, essentially XLA doesn't perform optimizations that would change numerical semantics too much, because that may be changing something the user wrote intentionally. You could imagine a compiler with a different philosophy, like "optimize as much as possible, numerics be damned", or a compiler that is even more conservative about numerics, or anything in between.

ywrt commented 2 months ago

You're right, I'm saying diag(JtJ), but it's a little ill-defined here. What I want is $\mathbb{E}[\text{diag}(J^T J)]$, which is approximated by the per-example diag(JtJ) averaged over the batch.

This is analogous to jax.grad returning $\mathbb{E}[J]$, being the per-example gradient averaged over the batch.

NB: $\mathbb{E}[\text{diag}(J^T J)]$ is obviously distinct from $\text{diag}(\mathbb{E}[J]^T \mathbb{E}[J])$, which is what taking diag(g.T @ g) for g = jax.grad(...)(...) would give.
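
In code, the distinction is just square-then-average versus average-then-square of the per-example gradients (toy stand-in arrays, purely illustrative):

import jax
import jax.numpy as jnp

per_ex_grads = jax.random.normal(jax.random.key(0), (16, 4))  # (batch, n) per-example gradients

e_diag_jtj = jnp.square(per_ex_grads).mean(axis=0)  # E[diag(JtJ)]: square, then average
g = per_ex_grads.mean(axis=0)                       # E[J], i.e. what jax.grad of the mean loss gives
diag_ee = jnp.square(g)                             # diag(E[J]'E[J]): average, then square

print(jnp.allclose(e_diag_jtj, diag_ee))  # False in general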

Does that make sense?