stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

Why do all axes have to be unique? #13

Closed · rohan-mehta-1024 closed this issue 11 months ago

rohan-mehta-1024 commented 11 months ago

Is there a fundamental limitation of the way named tensors are implemented that requires unique axis names, or is this something that could be changed in the future? E.g.,

import haliax as hax

# Embed, Batch, and Block are hax.Axis objects defined elsewhere

def cross_entropy(logits: hax.NamedArray, labels: hax.NamedArray) -> hax.NamedArray:
    """Compute the cross-entropy loss between a batch of logits and labels."""
    preds = hax.nn.log_softmax(logits, axis=Embed)
    loss = hax.take(preds, Embed, labels)  # extract the log probability of the correct token
    return -hax.mean(loss)

Here I have two named arrays, logits and labels, which overlap in the Batch and Block dimensions. So to use hax.take, I would first have to rename the axes of labels, which detracts from the point of the code a little bit:

def cross_entropy(logits: hax.NamedArray, labels: hax.NamedArray) -> hax.NamedArray:
    """Compute the cross-entropy loss between a batch of logits and labels."""
    preds = hax.nn.log_softmax(logits, axis=Embed)
    labels = hax.rename(labels, {Batch: "label_batch", Block: "label_block"})  # haliax requires unique axis names
    loss = hax.take(preds, Embed, labels)  # extract the log probability of the correct token
    return -hax.mean(loss)

Is there any way to get around this?

dlwh commented 11 months ago

It is a fundamental requirement that a named array have only one axis of each name: that's how axes are referenced. That said, this is a bug: take should broadcast over the shared batch axes so that this works. I'll try to fix it now.

(I don't know if it's just for a test, but the take should probably be over the Vocab axis if you're doing language modeling?)
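
For concreteness, here is a minimal sketch (not part of the original thread) of what the fix should allow, with take over a Vocab axis as suggested above. The axis names and sizes are made up, and the hax.random calls are assumed to mirror jax.random:

import jax.random as jrandom
import haliax as hax

Batch = hax.Axis("batch", 4)
Block = hax.Axis("block", 8)
Vocab = hax.Axis("vocab", 32)

def cross_entropy(logits: hax.NamedArray, labels: hax.NamedArray) -> hax.NamedArray:
    preds = hax.nn.log_softmax(logits, axis=Vocab)
    # take broadcasts over the shared Batch and Block axes,
    # so labels no longer need to be renamed
    loss = hax.take(preds, Vocab, labels)
    return -hax.mean(loss)

k1, k2 = jrandom.split(jrandom.PRNGKey(0))
logits = hax.random.normal(k1, (Batch, Block, Vocab))
labels = hax.random.randint(k2, (Batch, Block), 0, Vocab.size)  # token ids sharing Batch and Block
print(cross_entropy(logits, labels))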

dlwh commented 11 months ago

@rohan-mehta-1024 This should be fixed; let me know if not.

rohan-mehta-1024 commented 11 months ago

Yes, this works perfectly, thank you! (I'm making sure everything works in a simplified bigram model setting first before doing a full transformer architecture, so the vocab size is equal to the embedding size here, which is a little confusing.)
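
As an illustration (not from the thread) of that simplified bigram setting: the model's only parameter is a Vocab x Vocab logit table, so the "embedding" dimension ends up the same size as the vocabulary. Names and sizes below are invented, and hax.random.randint is assumed to mirror jax.random.randint:

import jax.random as jrandom
import haliax as hax

Vocab = hax.Axis("vocab", 32)
Embed = hax.Axis("embed", Vocab.size)  # bigram model: one logit per possible next token
Batch = hax.Axis("batch", 4)
Block = hax.Axis("block", 8)

k1, k2 = jrandom.split(jrandom.PRNGKey(0))
table = hax.random.normal(k1, (Vocab, Embed))  # bigram logit table
tokens = hax.random.randint(k2, (Batch, Block), 0, Vocab.size)
logits = hax.take(table, Vocab, tokens)  # result has Batch, Block, and Embed axes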