stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

Batched Take #14

Closed by dlwh 1 year ago

dlwh commented 1 year ago

Fixes #13 and similar cases where we want to index an array using a batched index, for example:

cc @rohan-mehta-1024

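    import jax.numpy as jnp
    from jax.random import PRNGKey

    import haliax as hax
    from haliax import Axis

    Embed = Axis("Embed", 10)
    Block = Axis("Block", 20)
    Batch = Axis("Batch", 30)

    def cross_entropy(logits: hax.NamedArray, labels: hax.NamedArray) -> hax.NamedArray:
        # Pick out the logit of the correct token at every (Batch, Block) position.
        # Batch and Block appear in both arrays, so they are batched over rather than crossed.
        # (A full cross-entropy would also apply log_softmax over Embed and negate the result.)
        return hax.take(logits, Embed, labels)

    logits = hax.random.uniform(PRNGKey(0), (Batch, Block, Embed))
    labels = hax.random.randint(PRNGKey(0), (Batch, Block), 0, Embed.size)

    loss = cross_entropy(logits, labels)
    assert loss.axes == (Batch, Block)
    # Same result as the positional jnp equivalent.
    assert jnp.all(loss.array == jnp.take_along_axis(logits.array, labels.array[..., None], axis=-1)[..., 0])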

This ends up being fairly complex to handle in the general case, but I think the result is good.
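To make the "batched index" semantics concrete outside JAX, here is a minimal NumPy sketch under stated assumptions. A named array is modelled as a plain ndarray plus a tuple of axis names, and `batched_take` is a hypothetical helper, not haliax's implementation. Axes that appear in both the array and the index (like `Batch` and `Block` above) are matched by name and batched over. Axes that appear only in the index are added to the result. Appending those index-only axes at the end is a simplification made by this sketch, and haliax may order them differently.

```python
import numpy as np


def batched_take(arr, arr_axes, axis, idx, idx_axes):
    """Named 'take' where axes shared by `arr` and `idx` are batched over.

    Hypothetical helper for illustration; not haliax's implementation.
    `arr_axes` / `idx_axes` are tuples of axis names, one per dimension.
    Returns (result, result_axes).
    """
    assert axis in arr_axes and axis not in idx_axes
    kept = [a for a in arr_axes if a != axis]           # array axes that survive
    extra = [a for a in idx_axes if a not in arr_axes]  # index-only axes (appended)
    out_axes = tuple(kept + extra)

    # Array: move the taken axis last, then add size-1 dims for the index-only axes.
    a = np.moveaxis(arr, arr_axes.index(axis), -1)
    a = a.reshape(a.shape[:-1] + (1,) * len(extra) + a.shape[-1:])

    # Index: reorder its dims to follow out_axes, with size-1 dims for axes it lacks.
    perm = [idx_axes.index(n) for n in out_axes if n in idx_axes]
    i = np.transpose(idx, perm)
    i = i.reshape([idx.shape[idx_axes.index(n)] if n in idx_axes else 1 for n in out_axes])

    # Shared axes now line up position by position; size-1 dims broadcast.
    out = np.take_along_axis(a, i[..., None], axis=-1)[..., 0]
    return out, out_axes


rng = np.random.default_rng(0)
logits = rng.uniform(size=(30, 20, 10))       # (Batch, Block, Embed)
labels = rng.integers(0, 10, size=(30, 20))   # (Batch, Block)
loss, axes = batched_take(logits, ("Batch", "Block", "Embed"), "Embed",
                          labels, ("Batch", "Block"))
print(axes)  # ('Batch', 'Block')

# An index axis the array lacks ("Choice") is added to the output instead.
choices = rng.integers(0, 10, size=(30, 3))   # (Batch, Choice)
picked, axes2 = batched_take(logits, ("Batch", "Block", "Embed"), "Embed",
                             choices, ("Batch", "Choice"))
print(axes2, picked.shape)  # ('Batch', 'Block', 'Choice') (30, 20, 3)
```

The size-1 dimensions let `np.take_along_axis` broadcast. This avoids materializing the `(Batch, Block, Batch, Block)` result that a plain positional `take` with a `(Batch, Block)` index would produce.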