stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

Add replica axes #87

Closed: blahBlahhhJ closed this 4 months ago

blahBlahhhJ commented 4 months ago

To implement pure data parallelism (DP) in Levanter.

dlwh commented 4 months ago

We should probably just move these to Levanter. If you bump the haliax dependency in Levanter, this should start working automatically once the build is published.