stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

Add `ravel` / `flatten` #18

Closed: dlwh closed this issue 10 months ago

dlwh commented 11 months ago

Pretty sure the syntax should be:

```
a.ravel(new_axis_name: AxisSelector)
```

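Here is a minimal, self-contained sketch of the proposed semantics. It assumes that `ravel` merges every axis of an array into one new named axis in row-major order, and that it is built on a more general flatten helper (the approach described later in this thread). `Axis`, `ToyNamedArray`, `flatten_axes` and `ravel` below are hypothetical NumPy stand-ins, not Haliax's actual classes or functions.

```python
from dataclasses import dataclass
from typing import Tuple

import numpy as np


# Hypothetical stand-ins for a named axis and a named array; NOT Haliax's classes.
@dataclass(frozen=True)
class Axis:
    name: str
    size: int


@dataclass
class ToyNamedArray:
    array: np.ndarray
    axes: Tuple[Axis, ...]

    def __post_init__(self):
        if self.array.shape != tuple(ax.size for ax in self.axes):
            raise ValueError("array shape does not match axis sizes")


def flatten_axes(a: ToyNamedArray, old_axes: Tuple[str, ...], new_axis: str) -> ToyNamedArray:
    """Hypothetical helper: merge the named `old_axes` into one axis called `new_axis`.

    The merged axes are gathered (in the order listed) at the position of the
    earliest of them, then reshaped in row-major (C) order.
    """
    if not old_axes:
        raise ValueError("need at least one axis to flatten")
    names = [ax.name for ax in a.axes]
    idx = [names.index(n) for n in old_axes]  # raises ValueError for unknown names
    keep = [i for i in range(len(a.axes)) if i not in idx]
    # Number of untouched axes that stay in front of the merged axis.
    pos = sum(1 for i in keep if i < min(idx))
    perm = keep[:pos] + idx + keep[pos:]
    moved = np.transpose(a.array, perm)
    size = int(np.prod([a.axes[i].size for i in idx]))
    new_shape = moved.shape[:pos] + (size,) + moved.shape[pos + len(idx):]
    new_axes = (
        tuple(a.axes[i] for i in keep[:pos])
        + (Axis(new_axis, size),)
        + tuple(a.axes[i] for i in keep[pos:])
    )
    return ToyNamedArray(moved.reshape(new_shape), new_axes)


def ravel(a: ToyNamedArray, new_axis_name: str) -> ToyNamedArray:
    """Hypothetical ravel: flatten *all* axes into one axis named `new_axis_name`."""
    return flatten_axes(a, tuple(ax.name for ax in a.axes), new_axis_name)


# Usage: a 2 x 3 x 4 array with named axes.
x = ToyNamedArray(
    np.arange(24).reshape(2, 3, 4),
    (Axis("batch", 2), Axis("height", 3), Axis("width", 4)),
)
flat = ravel(x, "flat")                          # flat.axes == (Axis("flat", 24),)
hw = flatten_axes(x, ("height", "width"), "hw")  # hw.axes == (Axis("batch", 2), Axis("hw", 12))
```

In this sketch, taking a name for the new axis (instead of NumPy's `order`) means the result still carries a named axis, and the general helper can choose which axes to merge and in what order.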
rohan-mehta-1024 commented 10 months ago

I was taking a look at this and wrote up a simple implementation that just calls `flatten_axes`. However, I saw this comment in the code, which seems to have a different type signature:

```python
# TODO: implement ravel. Can only do if we either ask for an axis or add ProductAxis or something
def ravel(self, order='C') -> Any:
```

Which signature should the method follow? And should it support the optional `order` argument?
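On the `order` question, one way to see what the argument does: in NumPy, column-major flattening (`order="F"`) gives the same result as reversing the axis order and then flattening in row-major order. With named axes, a caller could presumably get the same effect by choosing the order in which axes are flattened. This sketch uses plain NumPy only; it does not show or confirm any Haliax API.

```python
import numpy as np

# A 2 x 3 x 4 array; think of its axes as "batch", "height", "width".
a = np.arange(24).reshape(2, 3, 4)

# Row-major (C order): the last axis varies fastest.
c_order = a.ravel(order="C")

# Column-major (F order): the first axis varies fastest.
f_order = a.ravel(order="F")

# Same elements as reversing the axes and flattening in C order.
f_via_transpose = np.transpose(a, (2, 1, 0)).ravel(order="C")
```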

dlwh commented 10 months ago

Don't worry about that old TODO. I had a very different vision for Haliax early on, and that comment reflects what I was thinking at the time.

Thanks!

dlwh commented 10 months ago

Fixed in #33. Thanks, @rohan-mehta-1024!