stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0
140 stars, 9 forks

auto_shard inside zeros etc #62

Open dlwh opened 5 months ago

dlwh commented 5 months ago

I'm not sure we should do this for zeros_like (the sharding of the input isn't accessible in general, so we can either do nothing or call auto_shard), but it seems like zeros and the other creation functions should shard their output per the axis mapping.
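To make "shard per the axis mapping" concrete, here is a minimal pure-Python sketch of the lookup a creation function would need to do. It is not haliax's implementation: `partition_spec_for` is a hypothetical helper, and the axis names (`batch`, `position`, `embed`) and mesh axis names (`data`, `model`) are assumptions chosen for illustration. In JAX, a tuple like the one it returns is the kind of thing you would pass to `jax.sharding.PartitionSpec`.

```python
from typing import Dict, Optional, Sequence, Tuple


def partition_spec_for(
    axis_names: Sequence[str],
    axis_mapping: Dict[str, str],
) -> Tuple[Optional[str], ...]:
    """Hypothetical helper (not a haliax function).

    For each named axis of the array being created, look up the mesh axis
    it should be sharded over. Axes missing from the mapping get None,
    meaning "replicate along this axis".
    """
    return tuple(axis_mapping.get(name) for name in axis_names)


# Assumed axis mapping: logical axis name -> device-mesh axis name.
axis_mapping = {"batch": "data", "embed": "model"}

# Named axes of the array that a zeros-style call would create.
spec = partition_spec_for(("batch", "position", "embed"), axis_mapping)
print(spec)  # ('data', None, 'model')
```

This also shows why zeros is the easier case: the named axes come directly from the requested shape, so the lookup never needs to inspect the sharding of an existing array, which is what zeros_like would have to do.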