Closed by dlwh 1 year ago
Fixes #13 and similar cases where we want to index along an axis using a batched index:
cc @rohan-mehta-1024
import haliax as hax
import jax.numpy as jnp
from haliax import Axis
from jax.random import PRNGKey

def cross_entropy(logits: hax.NamedArray, labels: hax.NamedArray) -> hax.NamedArray:
    return hax.take(logits, Embed, labels)  # extract log probability of the correct token

Embed = Axis("Embed", 10)
Block = Axis("Block", 20)
Batch = Axis("Batch", 30)
logits = hax.random.uniform(PRNGKey(0), (Batch, Block, Embed))
labels = hax.random.randint(PRNGKey(0), (Batch, Block), 0, Embed.size)

loss = cross_entropy(logits, labels)
assert loss.axes == (Batch, Block)
# the batched take should match take_along_axis on the underlying arrays
assert jnp.all(loss.array == jnp.take_along_axis(logits.array, labels.array[..., None], axis=-1)[..., 0])
This ends up being fairly complex to handle in the general case, but I think the result is good.
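For reference, the behavior the batched `take` is expected to reproduce in the example above can be sketched in plain NumPy. This only illustrates the intended semantics for the simple case where the indexed axis is last and the index covers all remaining axes; it is not the implementation in this PR, and `batched_take_last_axis` is a hypothetical helper name.

```python
import numpy as np

def batched_take_last_axis(array, index):
    """Hypothetical helper: for every leading position, pick array[..., index[...]]."""
    # `index` has the shape of `array` minus its last axis. Add a trailing
    # axis so take_along_axis can align it, then drop that axis again.
    return np.take_along_axis(array, index[..., None], axis=-1)[..., 0]

rng = np.random.default_rng(0)
logits = rng.uniform(size=(30, 20, 10))      # (Batch, Block, Embed)
labels = rng.integers(0, 10, size=(30, 20))  # (Batch, Block)

picked = batched_take_last_axis(logits, labels)

# Explicit-loop reference: one Embed entry per (Batch, Block) position.
expected = np.empty((30, 20))
for b in range(30):
    for t in range(20):
        expected[b, t] = logits[b, t, labels[b, t]]

assert picked.shape == (30, 20)
assert np.array_equal(picked, expected)
```

The general case is harder than this sketch because the indexed axis can sit anywhere and the index may share only some axes with the array.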