rohan-mehta-1024 closed this 11 months ago
It is a fundamental requirement that a named array have only one axis of each name: that's how axes are referenced. That said, this is a bug: it should broadcast the batch axis so that this works. I'll try to fix it now.
(i dunno if it's just for a test, but the take should probably be over the Vocab axis if you're doing language modeling?)
@rohan-mehta-1024 should be fixed. lmk if not
Yes, this works perfectly, thank you! (I'm making sure everything works in a simplified bigram model setting before moving to a full transformer architecture, so the vocab size is equal to the embedding size here, which is a little confusing.)
Is there a fundamental limitation of the way named tensors are implemented that requires unique axis names, or is this something that could be changed in the future? E.g.,
here I have two named arrays, logits and labels, which share the Batch and Block axes. So to use hax.take, I would have to first rename the axes of labels, which detracts from the point of the code a little bit.
Is there any way to get around this?
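For anyone following along, here is a minimal sketch of the selection being discussed, written in plain NumPy rather than Haliax so it runs without the library. The sizes and variable names are made up for illustration, and the mapping to hax.take in the comments is an assumption about the intended call, not something confirmed in this thread. It just shows what "broadcasting" the shared Batch and Block axes means: each (batch, block) position picks one entry along the vocab axis, and no axis needs renaming.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
batch, block, vocab = 2, 3, 5

rng = np.random.default_rng(0)
logits = rng.normal(size=(batch, block, vocab))       # axes: (Batch, Block, Vocab)
labels = rng.integers(0, vocab, size=(batch, block))  # axes: (Batch, Block)

# For every (batch, block) position, pick the logit at that position's label.
# The shared Batch and Block axes are matched up, not multiplied out.
# Assumed named-array analogue: something like hax.take(logits, Vocab, labels).
picked = np.take_along_axis(logits, labels[..., None], axis=-1).squeeze(-1)

# The same selection written as an explicit loop, as a cross-check.
expected = np.empty((batch, block))
for b in range(batch):
    for t in range(block):
        expected[b, t] = logits[b, t, labels[b, t]]

assert picked.shape == (batch, block)
assert np.array_equal(picked, expected)
```

The result has only the Batch and Block axes, which is the shape you'd want for per-token label logits in a language-modeling loss.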