stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

Cleaner All Pairs Difference #106

Open 0xc1c4da opened 2 weeks ago

0xc1c4da commented 2 weeks ago

In torch and JAX it is possible to perform an all-pairs difference with a one-liner of broadcasting black magic: dt_segment_sum_jax = dA_cumsum_jax[:, :, :, :, None] - dA_cumsum_jax[:, :, :, None, :]

This is performed in the reference implementation of Mamba 2
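
For context, a minimal self-contained sketch of the plain-JAX version (array names and shapes here are illustrative, not taken from the Mamba 2 code): adding complementary singleton axes to two copies of the array makes them broadcast against each other, so diff[..., i, j] = x[..., i] - x[..., j].

import jax
import jax.numpy as jnp
from jax.random import PRNGKey

# x has shape (..., T); the two indexing expressions add complementary
# singleton axes, so broadcasting produces every pairwise difference
# along the last axis: diff[..., i, j] = x[..., i] - x[..., j]
x = jax.random.uniform(PRNGKey(0), (4, 5, 6, 11))
diff = x[..., :, None] - x[..., None, :]
assert diff.shape == (4, 5, 6, 11, 11)
assert jnp.allclose(diff[0, 0, 0, 2, 3], x[0, 0, 0, 2] - x[0, 0, 0, 3])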

That one-liner is neither human readable nor obvious about what it is doing, and it was not obvious how to express the equivalent in Haliax due to the subset constraint on broadcasting. A potential solution is below:

import jax.numpy as jnp
from jax.random import PRNGKey

import haliax as hax
from haliax import Axis


def test_all_pairs_difference():
    H = Axis("H", 7)
    W = Axis("W", 8)
    D = Axis("D", 9)
    T = Axis("T", 11)

    named1 = hax.random.uniform(PRNGKey(0), (H, W, D, T))
    # making sure this analogue works:
    #dt_segment_sum_jax = dA_cumsum_jax[:, :, :, :, None] - dA_cumsum_jax[:, :, :, None, :]
    named1_diff = named1.broadcast_axis(hax.Axis("T2", 11)) - named1.rename({"T": "T2"})
    named1_diff = named1_diff.rearrange((..., "T", "T2"))
    assert named1_diff.axes == (H, W, D, T, Axis("T2", 11))

    vanilla_diff = named1.array[:, :, :, :, None] - named1.array[:, :, :, None, :]

    assert jnp.all(named1_diff.array == vanilla_diff)
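
If this pattern comes up often, the workaround can be factored into a small helper. all_pairs_difference below is a hypothetical name and not part of the haliax API; it only uses the methods already exercised in the test above (broadcast_axis, rename, rearrange):

from haliax import Axis, NamedArray

def all_pairs_difference(x: NamedArray, axis: str, new_axis: str) -> NamedArray:
    # Hypothetical helper, not existing haliax API:
    # out[..., t, t2] = x[..., t] - x[..., t2]
    size = {ax.name: ax.size for ax in x.axes}[axis]
    diff = x.broadcast_axis(Axis(new_axis, size)) - x.rename({axis: new_axis})
    return diff.rearrange((..., axis, new_axis))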

This issue exists to request better support for this kind of operation.

dlwh commented 1 week ago

what do you think about:

with hax.auto_broadcast():
    named1_diff = named1 - named1.rename({"T": "T2"})

Basically the only thing stopping this from working is an explicit check I do to avoid accidentally combining arrays where one isn't a subset of the other.

The other thing I could do is relax the check to be "at least one overlapping axis"
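
hax.auto_broadcast does not exist in haliax today; purely as a sketch of the idea, it could be a context manager that flips a module-level flag which the existing subset check consults (all names below are hypothetical):

from contextlib import contextmanager

# Hypothetical flag that the broadcasting subset check would consult;
# none of this is existing haliax code.
_AUTO_BROADCAST = False

@contextmanager
def auto_broadcast():
    global _AUTO_BROADCAST
    old, _AUTO_BROADCAST = _AUTO_BROADCAST, True
    try:
        yield
    finally:
        _AUTO_BROADCAST = old

# Inside the existing check, the error would be skipped while the flag is set:
#     if not one_is_subset_of_other and not _AUTO_BROADCAST:
#         raise ValueError("axes of one operand must be a subset of the other")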

0xc1c4da commented 2 days ago

I think it is certainly cleaner, but I wouldn't remove the explicit check. Wouldn't it be better to explicitly disable it?

dlwh commented 2 days ago

Meaning you like the explicit with hax.auto_broadcast() approach?