stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

Make a helper function to squash/unsquash all axes (except some) into a single batch axis #67

Closed: dlwh closed this issue 9 months ago

dlwh commented 9 months ago

Lots of JAX functionality works with only a single batch axis, and custom kernels are also easier to write when there is exactly one batch axis. On TPU, this squashing is basically free (assuming the batch axes aren't at the beginning).
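For illustration, the squash/unsquash idea can be sketched as follows. This is a minimal sketch that assumes each array travels with a plain tuple of axis names. It does not reproduce haliax's actual NamedArray API, and the names `squash_batch_axes` and `unsquash_batch_axes` are hypothetical. The sketch moves every axis not listed in `keep` to the front (in its original order), then reshapes those axes into one leading batch axis. It also records enough information to undo both steps. It uses NumPy; `jax.numpy` exposes the same `transpose` and `reshape` calls.

```python
import math

import numpy as np  # jax.numpy provides the same transpose/reshape API


def squash_batch_axes(x, axis_names, keep):
    """Hypothetical helper: flatten every axis not in `keep` into one leading batch axis.

    x: array whose dimensions are labeled, in order, by `axis_names` (tuple of str).
    keep: names of axes to preserve, in the order they should follow the batch axis.
    Returns (flat, info), where `info` is what unsquash_batch_axes needs.
    """
    axis_names, keep = tuple(axis_names), tuple(keep)
    missing = [name for name in keep if name not in axis_names]
    if missing:
        raise ValueError(f"axes {missing} not found in {axis_names}")
    batch_names = tuple(n for n in axis_names if n not in keep)
    # Permutation putting all batch axes first (original order), then the kept axes.
    perm = [axis_names.index(n) for n in batch_names + keep]
    moved = np.transpose(x, perm)
    batch_shape = moved.shape[: len(batch_names)]
    kept_shape = moved.shape[len(batch_names):]
    # Row-major reshape merges the leading batch axes into a single axis.
    flat = moved.reshape((math.prod(batch_shape),) + kept_shape)
    info = (axis_names, batch_names, batch_shape, keep)
    return flat, info


def unsquash_batch_axes(flat, info):
    """Hypothetical inverse: split the batch axis and restore the original axis order."""
    axis_names, batch_names, batch_shape, keep = info
    # Split the single batch axis back into its original axes.
    moved = flat.reshape(tuple(batch_shape) + flat.shape[1:])
    # Invert the permutation applied in squash_batch_axes.
    current = batch_names + keep
    inverse = [current.index(n) for n in axis_names]
    return np.transpose(moved, inverse)


# Usage: keep only "embed"; batch, seq, and heads become one axis of size 2*3*4.
x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
names = ("batch", "seq", "heads", "embed")
flat, info = squash_batch_axes(x, names, keep=("embed",))
print(flat.shape)  # (24, 5)
assert np.array_equal(unsquash_batch_axes(flat, info), x)
```

Because the batch axes keep their relative order and are merged with a row-major reshape, a kernel that expects a `[batch, ...]` layout can be applied to `flat`. The result can then be unsquashed, provided the kernel leaves the leading axis and the kept axes in place. When the kept axes are not already trailing, the transpose may require a copy.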

dlwh commented 9 months ago

Fixed in dev. The helper is called haliax.core.flatten_all_axes_but.