dlwh closed this 9 months ago
Lots of JAX code works with only a single batch axis, and other custom kernels are easier to write when there is a single batch axis. On TPU, flattening the batch axes into one is basically free (assuming the batch axes aren't at the beginning).
Fixed in dev. It's called `haliax.core.flatten_all_axes_but`.
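The idea behind the helper can be sketched in plain NumPy. This is not haliax's implementation and does not reproduce its signature: haliax works with named axes, while the name `flatten_all_but`, the integer `keep` argument, and the returned `unflatten` closure below are hypothetical choices for illustration. The sketch moves the kept axes to the end, collapses every other axis into one leading batch axis, and returns a function that restores the original layout. Swapping `numpy` for `jax.numpy` should work, since `transpose`, `reshape`, and `argsort` exist there too.

```python
import numpy as np

def flatten_all_but(x, keep):
    """Hypothetical helper: collapse every axis not in `keep` into one leading batch axis.

    Returns (flat, unflatten). `flat` has shape (prod(batch dims), *kept dims).
    `unflatten` expects an array with the same number of trailing (kept) axes,
    though their sizes may differ, and restores the original axis order.
    """
    keep = [k % x.ndim for k in keep]                  # allow negative axis indices
    batch = [i for i in range(x.ndim) if i not in keep]
    perm = batch + keep                                # batch axes first, kept axes last
    moved = np.transpose(x, perm)                      # a real copy unless already in order
    batch_shape = moved.shape[:len(batch)]
    flat = moved.reshape((-1,) + moved.shape[len(batch):])

    def unflatten(y):
        # Split the single batch axis back into its original dims...
        y = y.reshape(batch_shape + y.shape[1:])
        # ...then undo the permutation so axes return to their original positions.
        return np.transpose(y, np.argsort(perm))

    return flat, unflatten

# Usage: a "kernel" that only understands (batch, a, b) inputs.
x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
flat, unflatten = flatten_all_but(x, keep=[1, 3])   # axes 0 and 2 become the batch axis
print(flat.shape)                                    # (8, 3, 5)
restored = unflatten(flat * 2)                       # run the op, then restore the layout
print(np.array_equal(restored, x * 2))               # True
```

Returning a closure keeps the bookkeeping (permutation and batch shape) next to the flatten call, so callers cannot accidentally unflatten with mismatched metadata.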