stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

Simplify sharding #79

Closed: dlwh closed this 8 months ago

dlwh commented 8 months ago

Based on the discussion in https://github.com/patrick-kidger/equinox/issues/688, we can now dramatically simplify this sharding logic.