google-deepmind / kfac-jax

Second Order Optimization and Curvature Estimation with K-FAC in JAX.
Apache License 2.0

Compatibility with `pallas` attention #244

Open ae-foster opened 1 month ago

ae-foster commented 1 month ago

First, thank you for creating and releasing this invaluable resource.

What I am trying to do

I would like to combine kfac-jax with the fused attention kernel from pallas.

As far as I understand, this should in principle be trivial: the attention operation itself contains no tagged parameters and no tagged losses, so K-FAC should only need to propagate the forward and backward passes through it.
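To make that expectation concrete, here is a minimal sketch with a plain `jax.numpy` attention in place of the pallas kernel. The names `reference_mha` and `model`, the single dense weight standing in for `hk.Linear`, and the small shapes are all assumptions made for illustration; nothing here is taken from kfac-jax or pallas. It only shows that a parameter-free attention op lets `jax.vjp` pass gradients through to the upstream weight.

```python
import jax
import jax.numpy as jnp


def reference_mha(q, k, v, sm_scale=1.0):
    """Plain softmax attention on [batch, seq_len, heads, head_dim] inputs."""
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) * sm_scale
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


def model(w, inputs):
    # Hypothetical stand-in for hk.Linear: one dense weight, no bias.
    h = (inputs @ w).reshape((*inputs.shape[:-1], 4, 8))
    # Reduce over everything except the batch axis.
    return reference_mha(h, h, h).mean(axis=(-1, -2, -3))


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (8, 32)) * 0.1
inputs = jax.random.normal(key_x, (2, 5, 8))

# The attention op holds no parameters, so the forward and backward
# passes simply flow through it to the weight `w`.
preds, vjp_fn = jax.vjp(lambda w_: model(w_, inputs), w)
(grad_w,) = vjp_fn(jnp.ones_like(preds))
```

If the pallas kernel behaved like this reference under the same transformations, the K-FAC tracing step would presumably not need any special handling for it.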

Reproducing failure with the current repo

Environment and set-up

Reproduction

scripts/kfac.py

import haiku as hk
import jax
import jax.experimental.pallas.ops.attention as attention
import jax.numpy as jnp
import kfac_jax

if __name__ == "__main__":

    def model(inputs):
        # shape [batch, nodes, features]
        k = hk.Linear(16 * 16)(inputs).reshape((*inputs.shape[:-1], 16, 16))
        attended = attention.mha(k, k, k, None)
        # reduce to shape batch
        y_hat = attended.mean([-1, -2, -3])
        return y_hat

    # The Haiku transformed model
    hk_model = hk.without_apply_rng(hk.transform(model))

    def loss_fn(model_params, model_batch):
        """The loss function to optimize."""
        x, y = model_batch
        preds = hk_model.apply(model_params, x)
        errs = (y - preds) ** 2
        kfac_jax.register_normal_predictive_distribution(errs)
        loss = jnp.mean(errs)

        return loss

    x = jnp.zeros((16, 16, 16))
    y = jnp.zeros(16)

    rng = jax.random.PRNGKey(42)
    rng, rng_init = jax.random.split(rng)
    params = hk_model.init(rng_init, x)

    # KFAC

    # Create the optimizer
    optimizer = kfac_jax.Optimizer(
        value_and_grad_func=jax.value_and_grad(loss_fn),
        l2_reg=0.0,
        value_func_has_aux=False,
        value_func_has_state=False,
        value_func_has_rng=False,
        use_adaptive_learning_rate=True,
        use_adaptive_momentum=False,
        use_adaptive_damping=True,
        initial_damping=1.0,
        multi_device=False,
    )

    rng, rng_opt = jax.random.split(rng)
    opt_state = optimizer.init(params, rng_opt, (x, y))
    params, opt_state, stats = optimizer.step(
        params, opt_state, rng, batch=(x, y), global_step_int=0, momentum=0
    )

This fails with the following error message:

Traceback (most recent call last):
  File "./scripts/kfac.py", line 167, in <module>
    opt_state = optimizer.init(params, rng_opt, (x, y))
  File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 1023, in init
    return self._init(params, rng, batch, func_state)
  File "./extern/kfac-jax/kfac_jax/_src/utils/staging.py", line 255, in decorated
    outs = jitted_func(instance, *args)
  File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 988, in _init
    estimator_state=self.estimator.init(
  File "./extern/kfac-jax/kfac_jax/_src/utils/misc.py", line 296, in wrapped
    return method(instance, *args, **kwargs)
  File "./extern/kfac-jax/kfac_jax/_src/curvature_estimator.py", line 1182, in init
    self.finalize(func_args)
  File "./extern/kfac-jax/kfac_jax/_src/utils/misc.py", line 266, in finalize
    self._finalize(*args, **kwargs)
  File "./extern/kfac-jax/kfac_jax/_src/curvature_estimator.py", line 1167, in _finalize
    self._jaxpr = self._jaxpr_extractor(func_args)
  File "./extern/kfac-jax/kfac_jax/_src/tracer.py", line 459, in get_processed_jaxpr
    closed_jaxpr, _ = retrieve(func_args)
  File "./extern/kfac-jax/kfac_jax/_src/tracer.py", line 425, in retrieve
    processed_jaxpr = ProcessedJaxpr.make_from_func(
  File "./extern/kfac-jax/kfac_jax/_src/tracer.py", line 314, in make_from_func
    func = tgm.auto_register_tags(
  File "./extern/kfac-jax/kfac_jax/_src/tag_graph_matcher.py", line 1614, in auto_register_tags
    graph = make_jax_graph(
  File "./extern/kfac-jax/kfac_jax/_src/tag_graph_matcher.py", line 336, in make_jax_graph
    closed_jaxpr, out_shapes = jax.make_jaxpr(func, return_shape=True)(*func_args)
  File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 1633, in value_func
    out, _ = value_and_grad_func(*args, **kwargs)
  File "./scripts/kfac.py", line 25, in loss_fn
    preds = hk_model.apply(model_params, x)
  File "/opt/env/lib/python3.11/site-packages/haiku/_src/multi_transform.py", line 314, in apply_fn
    return f.apply(params, None, *args, **kwargs)
  File "/opt/env/lib/python3.11/site-packages/haiku/_src/transform.py", line 183, in apply_fn
    out, state = f.apply(params, None, *args, **kwargs)
  File "/opt/env/lib/python3.11/site-packages/haiku/_src/transform.py", line 456, in apply_fn
    out = f(*args, **kwargs)
  File "./scripts/kfac.py", line 13, in model
    attended = attention.mha(k, k, k, None)
  File "/opt/env/lib/python3.11/site-packages/jax/experimental/pallas/ops/attention.py", line 287, in _mha_forward
    out, l, m = pl.pallas_call(
  File "/opt/env/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 589, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: `JaxprInputEffect` Read<7> does not have corresponding input: Var(id=125959648193344):float32[16,16].
 Equation: a:i32[] b:f32[16,16] c:f32[16] d:f32[16] e:f32[16,16] f:f32[16] g:f32[16] = scan[
  _split_transpose=False
  jaxpr={ lambda ; h:MemRef<None>{float32[16,16]} i:f32[16,16] j:MemRef<None>{float32[16,16]}
      k:MemRef<None>{float32[16,16]} l:f32[16,16] m:MemRef<None>{float32[16,16]}
      n:i32[] o:f32[16,16] p:f32[16] q:f32[16] r:f32[16,16] s:f32[16] t:f32[16]. let
      u:i32[] = add n 1
      v:i32[] = mul n 16
      w:f32[16,16] <- h[v:v+16,:]
      x:f32[16,16] <- k[v:v+16,:]
      y:f32[16,16] = transpose[permutation=(1, 0)] w
      z:f32[16,16] = transpose[permutation=(1, 0)] x
      ba:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] i y
      bb:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] l y
      bc:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] i z
      bd:f32[16,16] = add_any bb bc
      be:f32[16] = reduce_max[axes=(1,)] ba
      bf:f32[16,1] = reshape[dimensions=None new_sizes=(16, 1)] be
      bg:bool[16,16] = eq ba bf
      bh:f32[16,16] = convert_element_type[new_dtype=float32 weak_type=False] bg
      bi:f32[16] = reduce_sum[axes=(1,)] bh
      bj:f32[16,16] = mul bd bh
      bk:f32[16] = reduce_sum[axes=(1,)] bj
      bl:f32[16] = div bk bi
      bm:f32[16] = max p be
      bn:bool[16] = eq p bm
      bo:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
      bp:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
      bq:f32[16] = select_n bn bp bo
      br:bool[16] = eq be bm
      bs:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
      bt:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
      bu:f32[16] = select_n br bt bs
      bv:f32[16] = div bq bu
      bw:f32[16] = mul s bv
      bx:bool[16] = eq be bm
      by:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
      bz:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
      ca:f32[16] = select_n bx bz by
      cb:bool[16] = eq p bm
      cc:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
      cd:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
      ce:f32[16] = select_n cb cd cc
      cf:f32[16] = div ca ce
      cg:f32[16] = mul bl cf
      ch:f32[16] = add_any bw cg
      ci:f32[16] = sub p bm
      cj:f32[16] = sub s ch
      ck:f32[16] = exp ci
      cl:f32[16] = mul cj ck
      cm:f32[16] = mul ck q
      cn:f32[16] = mul cl q
      co:f32[16] = mul ck t
      cp:f32[16] = add_any cn co
      cq:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] bm
      cr:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] ch
      cs:f32[16,16] = sub ba cq
      ct:f32[16,16] = sub bd cr
      cu:f32[16,16] = exp cs
      cv:f32[16,16] = mul ct cu
      cw:f32[16] = reduce_sum[axes=(1,)] cu
      cx:f32[16] = reduce_sum[axes=(1,)] cv
      cy:f32[16] = add cm cw
      cz:f32[16] = add cp cx
      da:f32[16] = div 1.0 cy
      db:f32[16] = neg cz
      dc:f32[16] = mul db 1.0
      dd:f32[16] = integer_pow[y=-2] cy
      de:f32[16] = mul dc dd
      df:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] da
      dg:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] de
      dh:f32[16,16] = mul cu df
      di:f32[16,16] = mul cv df
      dj:f32[16,16] = mul cu dg
      dk:f32[16,16] = add_any di dj
      dl:f32[16] = mul cm da
      dm:f32[16] = mul cp da
      dn:f32[16] = mul cm de
      do:f32[16] = add_any dm dn
      dp:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] dl
      dq:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] do
      dr:f32[16,16] = mul dp o
      ds:f32[16,16] = mul dq o
      dt:f32[16,16] = mul dp r
      du:f32[16,16] = add_any ds dt
      dv:f32[16,16] <- j[v:v+16,:]
      dw:f32[16,16] <- m[v:v+16,:]
      dx:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] dh dv
      dy:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] dk dv
      dz:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] dh dw
      ea:f32[16,16] = add_any dy dz
      eb:f32[16,16] = add dr dx
      ec:f32[16,16] = add du ea
    in (u, eb, bm, cy, ec, ch, cz) }
  length=1
  linear=(False, False, False, True, True, True, False, False, False, False, True, True, True)
  num_carry=7
  num_consts=6
  reverse=False
  unroll=1
] ed ee ef eg eh ei ej ek el em en eo ep

 Jaxpr: { lambda a:f32[] b:f32[] c:f32[] d:f32[] e:i32[] f:f32[] g:f32[] h:f32[] i:i32[]; j:MemRef<None>{float32[16,16]}
    k:MemRef<None>{float32[16,16]} l:MemRef<None>{float32[16,16]} m:MemRef<None>{float32[16,16]}
    n:MemRef<None>{float32[16]} o:MemRef<None>{float32[16]} p:MemRef<None>{float32[16,16]}
    q:MemRef<None>{float32[16,16]} r:MemRef<None>{float32[16,16]} s:MemRef<None>{float32[16,16]}
    t:MemRef<None>{float32[16]} u:MemRef<None>{float32[16]}. let
    v:i32[] = program_id[axis=0] 
    w:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] a
    x:f32[16] = sub w b
    y:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] c
    z:f32[16,16] = broadcast_in_dim[broadcast_dimensions=() shape=(16, 16)] d
    ba:i32[] = mul v e
    bb:f32[16,16] <- j[ba:ba+16,:]
    bc:f32[16,16] <- p[ba:ba+16,:]
    bd:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
    be:f32[16,16] = broadcast_in_dim[broadcast_dimensions=() shape=(16, 16)] bd
    bf:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
    bg:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] bf
    bh:f32[] = convert_element_type[new_dtype=float32 weak_type=False] h
    bi:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] bh
    bj:i32[] bk:f32[16,16] bl:f32[16] bm:f32[16] bn:f32[16,16] bo:f32[16] bp:f32[16] = scan[
      _split_transpose=False
      jaxpr={ lambda ; bq:MemRef<None>{float32[16,16]} br:f32[16,16] bs:MemRef<None>{float32[16,16]}
          bt:MemRef<None>{float32[16,16]} bu:f32[16,16] bv:MemRef<None>{float32[16,16]}
          bw:i32[] bx:f32[16,16] by:f32[16] bz:f32[16] ca:f32[16,16] cb:f32[16] cc:f32[16]. let
          cd:i32[] = add bw 1
          ce:i32[] = mul bw 16
          cf:f32[16,16] <- bq[ce:ce+16,:]
          cg:f32[16,16] <- bt[ce:ce+16,:]
          ch:f32[16,16] = transpose[permutation=(1, 0)] cf
          ci:f32[16,16] = transpose[permutation=(1, 0)] cg
          cj:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] br ch
          ck:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] bu ch
          cl:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] br ci
          cm:f32[16,16] = add_any ck cl
          cn:f32[16] = reduce_max[axes=(1,)] cj
          co:f32[16,1] = reshape[dimensions=None new_sizes=(16, 1)] cn
          cp:bool[16,16] = eq cj co
          cq:f32[16,16] = convert_element_type[
            new_dtype=float32
            weak_type=False
          ] cp
          cr:f32[16] = reduce_sum[axes=(1,)] cq
          cs:f32[16,16] = mul cm cq
          ct:f32[16] = reduce_sum[axes=(1,)] cs
          cu:f32[16] = div ct cr
          cv:f32[16] = max by cn
          cw:bool[16] = eq by cv
          cx:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
          cy:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
          cz:f32[16] = select_n cw cy cx
          da:bool[16] = eq cn cv
          db:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
          dc:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
          dd:f32[16] = select_n da dc db
          de:f32[16] = div cz dd
          df:f32[16] = mul cb de
          dg:bool[16] = eq cn cv
          dh:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
          di:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
          dj:f32[16] = select_n dg di dh
          dk:bool[16] = eq by cv
          dl:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
          dm:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
          dn:f32[16] = select_n dk dm dl
          do:f32[16] = div dj dn
          dp:f32[16] = mul cu do
          dq:f32[16] = add_any df dp
          dr:f32[16] = sub by cv
          ds:f32[16] = sub cb dq
          dt:f32[16] = exp dr
          du:f32[16] = mul ds dt
          dv:f32[16] = mul dt bz
          dw:f32[16] = mul du bz
          dx:f32[16] = mul dt cc
          dy:f32[16] = add_any dw dx
          dz:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] cv
          ea:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] dq
          eb:f32[16,16] = sub cj dz
          ec:f32[16,16] = sub cm ea
          ed:f32[16,16] = exp eb
          ee:f32[16,16] = mul ec ed
          ef:f32[16] = reduce_sum[axes=(1,)] ed
          eg:f32[16] = reduce_sum[axes=(1,)] ee
          eh:f32[16] = add dv ef
          ei:f32[16] = add dy eg
          ej:f32[16] = div 1.0 eh
          ek:f32[16] = neg ei
          el:f32[16] = mul ek 1.0
          em:f32[16] = integer_pow[y=-2] eh
          en:f32[16] = mul el em
          eo:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] ej
          ep:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] en
          eq:f32[16,16] = mul ed eo
          er:f32[16,16] = mul ee eo
          es:f32[16,16] = mul ed ep
          et:f32[16,16] = add_any er es
          eu:f32[16] = mul dv ej
          ev:f32[16] = mul dy ej
          ew:f32[16] = mul dv en
          ex:f32[16] = add_any ev ew
          ey:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] eu
          ez:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] ex
          fa:f32[16,16] = mul ey bx
          fb:f32[16,16] = mul ez bx
          fc:f32[16,16] = mul ey ca
          fd:f32[16,16] = add_any fb fc
          fe:f32[16,16] <- bs[ce:ce+16,:]
          ff:f32[16,16] <- bv[ce:ce+16,:]
          fg:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] eq fe
          fh:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] et fe
          fi:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] eq ff
          fj:f32[16,16] = add_any fh fi
          fk:f32[16,16] = add fa fg
          fl:f32[16,16] = add fd fj
        in (cd, fk, cv, eh, fl, dq, ei) }
      length=1
      linear=(False, False, False, True, True, True, False, False, False, False, True, True, True)
      num_carry=7
      num_consts=6
      reverse=False
      unroll=1
    ] k bb l q bc r i z x y be bg bi
    fm:f32[16], n[ba:ba+16] <- n[ba:ba+16], bm
    fn:f32[16], t[ba:ba+16] <- t[ba:ba+16], bp
    fo:f32[16], o[ba:ba+16] <- o[ba:ba+16], bl
    fp:f32[16], u[ba:ba+16] <- u[ba:ba+16], bo
    fq:f32[16,16], m[ba:ba+16,:] <- m[ba:ba+16,:], bk
    fr:f32[16,16], s[ba:ba+16,:] <- s[ba:ba+16,:], bn
  in () }

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "./scripts/kfac.py", line 168, in <module>
    params, opt_state, stats = optimizer.step(
                               ^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 1339, in step
    return self._step(params, state, rng, batch, func_state, learning_rate, momentum, damping)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/utils/staging.py", line 255, in decorated
    outs = jitted_func(instance, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/utils/misc.py", line 296, in wrapped
    return method(instance, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 1130, in _step
    state = self._maybe_update_estimator_curvature(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 783, in _maybe_update_estimator_curvature
    return self._maybe_update_estimator_state(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 735, in _maybe_update_estimator_state
    state.estimator_state = lax.cond(
                            ^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/optimizer.py", line 755, in _update_estimator_curvature
    state = self.estimator.update_curvature_matrix_estimate(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/utils/misc.py", line 296, in wrapped
    return method(instance, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/curvature_estimator.py", line 1422, in update_curvature_matrix_estimate
    losses, losses_vjp = self._compute_losses_vjp(func_args)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/utils/misc.py", line 296, in wrapped
    return method(instance, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/curvature_estimator.py", line 1106, in _compute_losses_vjp
    return self._vjp(func_args)
           ^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/tracer.py", line 456, in wrapped_transformation
    return f(func_args, *args)
           ^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/tracer.py", line 871, in _layer_tag_vjp
    _, aux_vjp, losses_inputs = jax.vjp(forward_aux, aux_dict, has_aux=True)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/tracer.py", line 833, in forward_aux
    write(eqn.outvars, tgm.eval_jaxpr_eqn(eqn, input_values))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "./extern/kfac-jax/kfac_jax/_src/tag_graph_matcher.py", line 72, in eval_jaxpr_eqn
    output = eqn.primitive.bind(*subfuns, *in_values, **bind_params)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/env/lib/python3.11/site-packages/jax/_src/pallas/pallas_call.py", line 257, in _pallas_call_jvp_rule
    jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, [])
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: `JaxprInputEffect` Read<7> does not have corresponding input: Var(id=125959648193344):float32[16,16].
 Equation: a:i32[] b:f32[16,16] c:f32[16] d:f32[16] e:f32[16,16] f:f32[16] g:f32[16] = scan[
  _split_transpose=False
  jaxpr={ lambda ; h:MemRef<None>{float32[16,16]} i:f32[16,16] j:MemRef<None>{float32[16,16]}
      k:MemRef<None>{float32[16,16]} l:f32[16,16] m:MemRef<None>{float32[16,16]}
      n:i32[] o:f32[16,16] p:f32[16] q:f32[16] r:f32[16,16] s:f32[16] t:f32[16]. let
      u:i32[] = add n 1
      v:i32[] = mul n 16
      w:f32[16,16] <- h[v:v+16,:]
      x:f32[16,16] <- k[v:v+16,:]
      y:f32[16,16] = transpose[permutation=(1, 0)] w
      z:f32[16,16] = transpose[permutation=(1, 0)] x
      ba:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] i y
      bb:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] l y
      bc:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] i z
      bd:f32[16,16] = add_any bb bc
      be:f32[16] = reduce_max[axes=(1,)] ba
      bf:f32[16,1] = reshape[dimensions=None new_sizes=(16, 1)] be
      bg:bool[16,16] = eq ba bf
      bh:f32[16,16] = convert_element_type[new_dtype=float32 weak_type=False] bg
      bi:f32[16] = reduce_sum[axes=(1,)] bh
      bj:f32[16,16] = mul bd bh
      bk:f32[16] = reduce_sum[axes=(1,)] bj
      bl:f32[16] = div bk bi
      bm:f32[16] = max p be
      bn:bool[16] = eq p bm
      bo:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
      bp:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
      bq:f32[16] = select_n bn bp bo
      br:bool[16] = eq be bm
      bs:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
      bt:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
      bu:f32[16] = select_n br bt bs
      bv:f32[16] = div bq bu
      bw:f32[16] = mul s bv
      bx:bool[16] = eq be bm
      by:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
      bz:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
      ca:f32[16] = select_n bx bz by
      cb:bool[16] = eq p bm
      cc:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
      cd:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
      ce:f32[16] = select_n cb cd cc
      cf:f32[16] = div ca ce
      cg:f32[16] = mul bl cf
      ch:f32[16] = add_any bw cg
      ci:f32[16] = sub p bm
      cj:f32[16] = sub s ch
      ck:f32[16] = exp ci
      cl:f32[16] = mul cj ck
      cm:f32[16] = mul ck q
      cn:f32[16] = mul cl q
      co:f32[16] = mul ck t
      cp:f32[16] = add_any cn co
      cq:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] bm
      cr:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] ch
      cs:f32[16,16] = sub ba cq
      ct:f32[16,16] = sub bd cr
      cu:f32[16,16] = exp cs
      cv:f32[16,16] = mul ct cu
      cw:f32[16] = reduce_sum[axes=(1,)] cu
      cx:f32[16] = reduce_sum[axes=(1,)] cv
      cy:f32[16] = add cm cw
      cz:f32[16] = add cp cx
      da:f32[16] = div 1.0 cy
      db:f32[16] = neg cz
      dc:f32[16] = mul db 1.0
      dd:f32[16] = integer_pow[y=-2] cy
      de:f32[16] = mul dc dd
      df:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] da
      dg:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] de
      dh:f32[16,16] = mul cu df
      di:f32[16,16] = mul cv df
      dj:f32[16,16] = mul cu dg
      dk:f32[16,16] = add_any di dj
      dl:f32[16] = mul cm da
      dm:f32[16] = mul cp da
      dn:f32[16] = mul cm de
      do:f32[16] = add_any dm dn
      dp:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] dl
      dq:f32[16,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(16, 1)] do
      dr:f32[16,16] = mul dp o
      ds:f32[16,16] = mul dq o
      dt:f32[16,16] = mul dp r
      du:f32[16,16] = add_any ds dt
      dv:f32[16,16] <- j[v:v+16,:]
      dw:f32[16,16] <- m[v:v+16,:]
      dx:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] dh dv
      dy:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] dk dv
      dz:f32[16,16] = dot_general[
        dimension_numbers=(([1], [0]), ([], []))
        preferred_element_type=float32
      ] dh dw
      ea:f32[16,16] = add_any dy dz
      eb:f32[16,16] = add dr dx
      ec:f32[16,16] = add du ea
    in (u, eb, bm, cy, ec, ch, cz) }
  length=1
  linear=(False, False, False, True, True, True, False, False, False, False, True, True, True)
  num_carry=7
  num_consts=6
  reverse=False
  unroll=1
] ed ee ef eg eh ei ej ek el em en eo ep

 Jaxpr: { lambda a:f32[] b:f32[] c:f32[] d:f32[] e:i32[] f:f32[] g:f32[] h:f32[] i:i32[]; j:MemRef<None>{float32[16,16]}
    k:MemRef<None>{float32[16,16]} l:MemRef<None>{float32[16,16]} m:MemRef<None>{float32[16,16]}
    n:MemRef<None>{float32[16]} o:MemRef<None>{float32[16]} p:MemRef<None>{float32[16,16]}
    q:MemRef<None>{float32[16,16]} r:MemRef<None>{float32[16,16]} s:MemRef<None>{float32[16,16]}
    t:MemRef<None>{float32[16]} u:MemRef<None>{float32[16]}. let
    v:i32[] = program_id[axis=0] 
    w:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] a
    x:f32[16] = sub w b
    y:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] c
    z:f32[16,16] = broadcast_in_dim[broadcast_dimensions=() shape=(16, 16)] d
    ba:i32[] = mul v e
    bb:f32[16,16] <- j[ba:ba+16,:]
    bc:f32[16,16] <- p[ba:ba+16,:]
    bd:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
    be:f32[16,16] = broadcast_in_dim[broadcast_dimensions=() shape=(16, 16)] bd
    bf:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
    bg:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] bf
    bh:f32[] = convert_element_type[new_dtype=float32 weak_type=False] h
    bi:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] bh
    bj:i32[] bk:f32[16,16] bl:f32[16] bm:f32[16] bn:f32[16,16] bo:f32[16] bp:f32[16] = scan[
      _split_transpose=False
      jaxpr={ lambda ; bq:MemRef<None>{float32[16,16]} br:f32[16,16] bs:MemRef<None>{float32[16,16]}
          bt:MemRef<None>{float32[16,16]} bu:f32[16,16] bv:MemRef<None>{float32[16,16]}
          bw:i32[] bx:f32[16,16] by:f32[16] bz:f32[16] ca:f32[16,16] cb:f32[16] cc:f32[16]. let
          cd:i32[] = add bw 1
          ce:i32[] = mul bw 16
          cf:f32[16,16] <- bq[ce:ce+16,:]
          cg:f32[16,16] <- bt[ce:ce+16,:]
          ch:f32[16,16] = transpose[permutation=(1, 0)] cf
          ci:f32[16,16] = transpose[permutation=(1, 0)] cg
          cj:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] br ch
          ck:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] bu ch
          cl:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] br ci
          cm:f32[16,16] = add_any ck cl
          cn:f32[16] = reduce_max[axes=(1,)] cj
          co:f32[16,1] = reshape[dimensions=None new_sizes=(16, 1)] cn
          cp:bool[16,16] = eq cj co
          cq:f32[16,16] = convert_element_type[
            new_dtype=float32
            weak_type=False
          ] cp
          cr:f32[16] = reduce_sum[axes=(1,)] cq
          cs:f32[16,16] = mul cm cq
          ct:f32[16] = reduce_sum[axes=(1,)] cs
          cu:f32[16] = div ct cr
          cv:f32[16] = max by cn
          cw:bool[16] = eq by cv
          cx:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
          cy:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
          cz:f32[16] = select_n cw cy cx
          da:bool[16] = eq cn cv
          db:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
          dc:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
          dd:f32[16] = select_n da dc db
          de:f32[16] = div cz dd
          df:f32[16] = mul cb de
          dg:bool[16] = eq cn cv
          dh:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
          di:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 0.0
          dj:f32[16] = select_n dg di dh
          dk:bool[16] = eq by cv
          dl:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 2.0
          dm:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
          dn:f32[16] = select_n dk dm dl
          do:f32[16] = div dj dn
          dp:f32[16] = mul cu do
          dq:f32[16] = add_any df dp
          dr:f32[16] = sub by cv
          ds:f32[16] = sub cb dq
          dt:f32[16] = exp dr
          du:f32[16] = mul ds dt
          dv:f32[16] = mul dt bz
          dw:f32[16] = mul du bz
          dx:f32[16] = mul dt cc
          dy:f32[16] = add_any dw dx
          dz:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] cv
          ea:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] dq
          eb:f32[16,16] = sub cj dz
          ec:f32[16,16] = sub cm ea
          ed:f32[16,16] = exp eb
          ee:f32[16,16] = mul ec ed
          ef:f32[16] = reduce_sum[axes=(1,)] ed
          eg:f32[16] = reduce_sum[axes=(1,)] ee
          eh:f32[16] = add dv ef
          ei:f32[16] = add dy eg
          ej:f32[16] = div 1.0 eh
          ek:f32[16] = neg ei
          el:f32[16] = mul ek 1.0
          em:f32[16] = integer_pow[y=-2] eh
          en:f32[16] = mul el em
          eo:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] ej
          ep:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] en
          eq:f32[16,16] = mul ed eo
          er:f32[16,16] = mul ee eo
          es:f32[16,16] = mul ed ep
          et:f32[16,16] = add_any er es
          eu:f32[16] = mul dv ej
          ev:f32[16] = mul dy ej
          ew:f32[16] = mul dv en
          ex:f32[16] = add_any ev ew
          ey:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] eu
          ez:f32[16,1] = broadcast_in_dim[
            broadcast_dimensions=(0,)
            shape=(16, 1)
          ] ex
          fa:f32[16,16] = mul ey bx
          fb:f32[16,16] = mul ez bx
          fc:f32[16,16] = mul ey ca
          fd:f32[16,16] = add_any fb fc
          fe:f32[16,16] <- bs[ce:ce+16,:]
          ff:f32[16,16] <- bv[ce:ce+16,:]
          fg:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] eq fe
          fh:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] et fe
          fi:f32[16,16] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] eq ff
          fj:f32[16,16] = add_any fh fi
          fk:f32[16,16] = add fa fg
          fl:f32[16,16] = add fd fj
        in (cd, fk, cv, eh, fl, dq, ei) }
      length=1
      linear=(False, False, False, True, True, True, False, False, False, False, True, True, True)
      num_carry=7
      num_consts=6
      reverse=False
      unroll=1
    ] k bb l q bc r i z x y be bg bi
    fm:f32[16], n[ba:ba+16] <- n[ba:ba+16], bm
    fn:f32[16], t[ba:ba+16] <- t[ba:ba+16], bp
    fo:f32[16], o[ba:ba+16] <- o[ba:ba+16], bl
    fp:f32[16], u[ba:ba+16] <- u[ba:ba+16], bo
    fq:f32[16,16], m[ba:ba+16,:] <- m[ba:ba+16,:], bk
    fr:f32[16,16], s[ba:ba+16,:] <- s[ba:ba+16,:], bn
  in () }
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

My own attempts to investigate

We can see that the issue arises when running vjp on the manipulated version of the model. To diagnose it, I tried a minimal re-implementation of this part of the KFAC algorithm:

# With the same model, etc from scripts/kfac.py above

import functools

primal_func_args = [params, (x, y)]

def read_env(
    env,
    variables,
):
    """Reads from the variable-to-array environment during tracing."""
    result = []
    assert isinstance(variables, list)
    for v in variables:
        if isinstance(v, jax.core.Literal):
            # Literals are values baked into the Jaxpr
            result.append(v.val)
        else:
            result.append(env[v])
    return result

def write_env(
    env,
    variables,
    values,
) -> None:
    """Writes to the variable-to-array environment during tracing."""
    assert len(variables) == len(values)
    for var, val in zip(variables, values):
        env[var] = val

def eval_jaxpr_eqn(eqn, in_values):
    """Computes the outputs of the given Jaxpr equation."""

    subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)

    output = eqn.primitive.bind(*subfuns, *in_values, **bind_params)

    if not isinstance(output, list):
        return [output]
    else:
        return output

processed_jaxpr = jax.make_jaxpr(jax.value_and_grad(loss_fn))(params, (x, y))

layer_input_vars = list(u for eqn in processed_jaxpr.eqns for u in eqn.invars)

def forward():
    """Computes the values of all inputs to all **layer** tags."""

    own_func_args = primal_func_args

    # Mapping from variable -> value
    env = {}
    read = functools.partial(read_env, env)
    write = functools.partial(write_env, env)

    # Bind args and consts to environment
    write(processed_jaxpr.jaxpr.invars, jax.tree_util.tree_leaves(own_func_args))
    write(processed_jaxpr.jaxpr.constvars, processed_jaxpr.consts)

    # Loop through equations and evaluate them
    for eqn in processed_jaxpr.jaxpr.eqns:

        write(eqn.outvars, eval_jaxpr_eqn(eqn, read(eqn.invars)))

    return tuple(read(layer_input_vars))

input_values = forward()

def forward_aux(aux):

    own_func_args = primal_func_args

    # Mapping from variable -> value
    env = {}
    read = functools.partial(read_env, env)

    def write(variables: list[jax.core.Var], values) -> None:
        # if not isinstance(variables, list):
        #   variables = [variables]
        write_env(env, variables, values)

        for v in variables:
            if not isinstance(v, jax.core.Literal) and v in aux:
                env[v] = env[v] + aux[v]

    # Bind args and consts to environment
    write(processed_jaxpr.jaxpr.invars, jax.tree_util.tree_leaves(own_func_args))

    write(processed_jaxpr.jaxpr.constvars, processed_jaxpr.consts)

    # Loop through equations and evaluate primitives using `bind`
    losses_p_dependants = []
    losses_inputs_values = []

    for eqn in processed_jaxpr.jaxpr.eqns:

        input_values = read(eqn.invars)
        out = eval_jaxpr_eqn(eqn, input_values)
        write(eqn.outvars, out)

        losses_inputs_values.append(tuple(input_values))

    return tuple(losses_p_dependants), tuple(losses_inputs_values)

aux_dict = jax.tree_util.tree_map(jnp.zeros_like, input_values)
my_outputs, my_outputs_aux = forward_aux(aux_dict)
_, aux_vjp, losses_inputs = jax.vjp(forward_aux, aux_dict, has_aux=True)
print("It worked.")

This runs without error. The reimplementation is based on my reading of the kfac-jax codebase and may be missing something important. In particular, it omits the loss and layer tagging; I had hoped that wasn't relevant.

Request for assistance

I would really appreciate your advice on this task. Specifically

  1. what is the root cause of the current failure?
  2. would it be trivial to fix kfac-jax to work with pallas attention? I would be happy to help work on a fix with guidance on where to look
  3. if it is not trivial, would it be possible to hack kfac-jax to work specifically with the attention operation, assuming that this operation contains no layer tags and no loss tags?

botev commented 1 month ago

Hi,

So this is a relatively complicated issue. The current code is incompatible with anything that uses custom_vjp, which the pallas attention does. Internally, kfac-jax differentiates value_and_grad_func(*args)[0], i.e. it extracts the forward pass from the user-provided value_and_grad_func. However, as I have now found out, jax.value_and_grad strips away the custom_vjp primitive. Hence, at the point where kfac-jax differentiates the function there is no custom_vjp and instead there is just mha_forward. We need to think about how best to address this. A few workarounds while we decide:

  1. Extend the optimizer, such that it accepts the forward_func, and internally creates a value_and_grad_func with jax.value_and_grad. This way you can pass the real forward_func to the curvature estimator.

  2. Instead of using jax.value_and_grad pass a function that returns f(x), jax.grad(f)(x). This way the forward pass that kfac-jax will extract will contain the custom_vjp.
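Workaround 2 can be sketched as follows. This is a minimal illustration, not kfac-jax code: `loss_fn`, `params`, and `batch` are hypothetical stand-ins for the user's own loss and data, and the point is only the shape of the function that would be passed in place of jax.value_and_grad(loss_fn).

```python
import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    """Hypothetical toy loss standing in for the real model loss."""
    x, y = batch
    preds = x @ params["w"]
    return jnp.mean((y - preds) ** 2)


def value_and_grad_func(params, batch):
    """Returns (f(x), grad f(x)) without going through jax.value_and_grad."""
    value = loss_fn(params, batch)            # plain forward call
    grads = jax.grad(loss_fn)(params, batch)  # gradient computed separately
    return value, grads


params = {"w": jnp.ones((3,))}
batch = (jnp.arange(12.0).reshape(4, 3), jnp.zeros(4))
value, grads = value_and_grad_func(params, batch)
```

The returned pair has the same structure as the output of jax.value_and_grad(loss_fn); the trade-off is that the forward pass is traced twice.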

PS: I tried 1) and it works. However, one more complication arises later: KFAC uses forward-mode autodiff, and you cannot apply forward-mode AD to a custom_vjp function.

To address this, my understanding is the following. Because Jax currently does not support custom_vjp and custom_jvp together, and a pallas_call would not work with the automatic derivation of the vjp from a custom_jvp, one would need to create a new jax.core.Primitive for the pallas_call that correctly wraps both the vjp and the jvp. Someone would also need to implement the jvp rule for the attention Pallas call. At the moment I don't think I have the capacity to do that though.

ae-foster commented 1 month ago

Thanks @botev , this has given me a good direction to look in.

I have hacked my kfac-jax as per your suggestion 1), and I can now get the new error

TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.

presumably the same as you got.
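The error can be reproduced without pallas or kfac-jax. A minimal sketch, using a hypothetical toy function `f` in place of the attention kernel: reverse mode goes through the custom_vjp rule, while jax.jvp raises the TypeError.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x):
    return jnp.sin(x)


def f_fwd(x):
    # Save x as the residual needed by the backward rule.
    return jnp.sin(x), x


def f_bwd(x, g):
    # Cotangent of sin(x) with respect to x.
    return (g * jnp.cos(x),)


f.defvjp(f_fwd, f_bwd)

# Reverse mode uses the custom rule and works.
grad_ok = jax.grad(f)(1.0)

# Forward mode is rejected for custom_vjp functions.
try:
    jax.jvp(f, (1.0,), (1.0,))
    jvp_error = None
except TypeError as e:
    jvp_error = str(e)
```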

I can take a look at implementing a pallas kernel for custom_jvp for attention. Will keep you posted

botev commented 1 month ago

Adding a custom_jvp might be infeasible, since Jax currently allows only one custom decorator per function. Most likely the whole attention call needs to be wrapped as a Jax primitive. But if you get anything to work, let me know.

ae-foster commented 1 month ago

Could I (temporarily) drop the custom vjp and have only a custom jvp? We could at least see if it works

botev commented 1 month ago

you can try? I think jax might be able to infer the vjp from the jvp?
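For ordinary JAX operations this inference does work: when the tangent output of a custom_jvp rule is linear in the input tangent and built from transposable primitives, JAX derives the vjp by transposing it. A minimal sketch with a hypothetical toy function `g` (no pallas involved):

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def g(x):
    return jnp.sin(x)


@g.defjvp
def g_jvp(primals, tangents):
    (x,) = primals
    (t,) = tangents
    # The tangent output is linear in t, so JAX can transpose it.
    return g(x), jnp.cos(x) * t


# Forward mode uses the custom rule directly.
primal_out, tangent_out = jax.jvp(g, (1.0,), (1.0,))

# Reverse mode is derived automatically from the jvp rule.
grad_value = jax.grad(g)(1.0)
```

Whether the same holds when the rule calls pallas_call depends on pallas_call having a transpose rule.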

ae-foster commented 1 month ago

Quick update on this. I tried implementing a custom_jvp on a simple pallas kernel for addition instead of full blown attention. Unfortunately, we hit the issue that transposition rules are not implemented for pallas_call and hence we cannot run vjp, see https://github.com/google/jax/issues/19146

Fortunately, it seems the jax folks are working on related issues although things are not yet fully documented / released. I found this snippet based on https://github.com/google/jax/issues/9129 . Importantly, custom_transpose is available but not fully documented in the latest jax.

The simplest example of this is

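# Imports assumed (the original snippet omitted them). custom_transpose is
# undocumented, so its import path may differ between JAX versions.
import jax
from jax import core, custom_jvp
from jax.custom_transpose import custom_transpose
from jax.experimental import pallas as pl
from jax.tree_util import tree_map

def add_vectors_kernel(x_ref, y_ref, o_ref):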
    x, y = x_ref[...], y_ref[...]
    o_ref[...] = x + y

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    return pl.pallas_call(add_vectors_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(
        x, y
    )

@jax.jit
def add_vectors_fwd(x: jax.Array, y: jax.Array) -> jax.Array:
    s = pl.pallas_call(add_vectors_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(
        x, y
    )
    return s, s

@jax.jit
def add_vectors_tangent(residuals, tangents, out_avals=None) -> jax.Array:
    x_dot, y_dot = tangents
    grad_out = pl.pallas_call(
        add_vectors_kernel, out_shape=jax.ShapeDtypeStruct(x_dot.shape, x_dot.dtype)
    )(x_dot, y_dot)
    return grad_out

@jax.jit
def add_vectors_bwd(residuals, cotangents) -> jax.Array:
    return (cotangents, cotangents)

def custom_vjp_by_custom_transpose(fun, tangent, fwd, bwd):
    fun = custom_jvp(fun)

    @fun.defjvp
    def jvp(primals, tangents):
        outs, residuals = fwd(*primals)
        tan_out_types = tree_map(lambda o: core.get_aval(o).at_least_vspace(), outs)
        tan_fn = custom_transpose(tangent)
        tan_fn.def_transpose(bwd)
        return outs, tan_fn(tan_out_types, residuals, tangents)

    return fun

pallas_add = custom_vjp_by_custom_transpose(
    add_vectors, add_vectors_tangent, add_vectors_fwd, add_vectors_bwd
)

This results in a function that can be differentiated with both jvp and vjp. If in my earlier script, I replace

attended = attention.mha(k, k, k, None)

with

attended = pallas_add(k, k)

then I can successfully run a step of KFAC :)

We may be on the right path.

ae-foster commented 1 month ago

I can confirm that the above approach also works with attention. I guess the only change needed from the kfac-jax end then is the ability to directly input the value function so that the custom_jvp instruction is preserved.

botev commented 1 month ago

Can you try replacing with:

def custom_vjp_by_custom_transpose(fun, tangent, fwd, bwd):
    fun = custom_jvp(fun)

    @fun.defjvp
    def jvp(primals, tangents):
        _, residuals = fwd(*primals)
        outs = fun(*primals)  # Make the primal output a `custom_jvp` call.
        tan_out_types = tree_map(lambda o: core.get_aval(o).at_least_vspace(), outs)
        tan_fn = custom_transpose(tangent)
        tan_fn.def_transpose(bwd)
        return outs, tan_fn(tan_out_types, residuals, tangents)

    return fun

ae-foster commented 1 month ago

Hi @botev Thanks for your tip above about calling fun itself inside the custom_jvp rule. We tried this out and it works with the old convert_value_and_grad_to_value pattern.

Some further updates on migrating attention from our testing script into our real code:

  1. Unfortunately, the custom_transpose pattern doesn't work in more complex settings because it is not compatible with vmap. (This is a known issue with that feature, which is probably why it's not documented. There seems to be a new suggestion for how to combine jvp and vjp: https://github.com/google/jax/pull/22457 )
  2. However, there was an important difference between the code snippet I included above and the real code. In particular, kfac-jax by default does not correctly register dense layers with batching dimensions. In our production code, we have extra patterns to fix this.
  3. That meant that in the code snippet I provided, KFAC was treating our linear layers as orphans, requiring ImplicitCurvature estimation that needs jvp and vjp.
  4. If the layer is correctly registered, however, it transpires that KFAC only ever calls vjp on the model. That is, we do not actually need to patch the attention layer to support jvp if we only have dense layers in the model.

botev commented 1 week ago

Thanks for the update. Unfortunately, 1. is in the hands of the Jax team and there is not much that we can do about it.

Points 2. and 3. should be addressed by #263.

For 4., I think it depends on whether you are using any of the self-tuning part of the optimizer, since the code here does use jvp.