stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

named arrays and `eqxi.while_loop(mode=checkpointed)` do not get along #66

Closed: dlwh closed this issue 4 months ago.

dlwh commented 5 months ago

The shape checks cause problems when named arrays are passed through the checkpointed loop.
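The report does not say which checks fail or why. The following is background only, not a diagnosis of this issue. JAX's pytree documentation warns that transformations may rebuild custom pytree classes with placeholder leaves such as `object()` or `None`. A constructor that eagerly validates `.shape` can therefore raise even though the real data is fine. The sketch below uses two hypothetical toy classes, `StrictNamed` and `LenientNamed`, rather than Haliax's actual `NamedArray`. It simulates that situation by calling `tree_unflatten` directly with an `object()` leaf.

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten, tree_unflatten


def _sizes(axes):
    # axes is a tuple of (name, size) pairs, e.g. (("batch", 3),)
    return tuple(size for _, size in axes)


@register_pytree_node_class
class StrictNamed:
    """Hypothetical toy named array that validates its shape eagerly."""

    def __init__(self, array, axes):
        # Assumes every leaf is a real array; placeholders like object() or None
        # have no .shape, so this line raises AttributeError for them.
        if tuple(array.shape) != _sizes(axes):
            raise ValueError(f"shape {array.shape} does not match axes {axes}")
        self.array = array
        self.axes = axes

    def tree_flatten(self):
        # The array is the only leaf; the axes travel as static aux data.
        return (self.array,), self.axes

    @classmethod
    def tree_unflatten(cls, axes, children):
        # Rebuilding goes through __init__, so the check runs on whatever leaf JAX supplies.
        return cls(children[0], axes)


@register_pytree_node_class
class LenientNamed(StrictNamed):
    """Same toy class, but only checks leaves that actually have a shape."""

    def __init__(self, array, axes):
        if hasattr(array, "shape") and tuple(array.shape) != _sizes(axes):
            raise ValueError(f"shape {array.shape} does not match axes {axes}")
        self.array = array
        self.axes = axes


if __name__ == "__main__":
    axes = (("batch", 3),)

    # Lenient: rebuilding with a placeholder leaf succeeds because the check is skipped.
    _, lenient_def = tree_flatten(LenientNamed(jnp.zeros(3), axes))
    print(type(tree_unflatten(lenient_def, [object()])).__name__)

    # Strict: the same rebuild fails inside __init__.
    _, strict_def = tree_flatten(StrictNamed(jnp.zeros(3), axes))
    try:
        tree_unflatten(strict_def, [object()])
    except AttributeError as err:
        print("strict check failed:", err)
```

The lenient variant still catches genuine mismatches on real arrays (for example, `LenientNamed(jnp.zeros(4), axes)` raises `ValueError`) while tolerating non-array placeholders. Whether this is the mechanism behind the checkpointed-loop failure is an assumption the issue does not confirm.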