Open · dlwh opened 8 months ago
Currently our FSDP implementation uses JAX's sharding machinery, which requires that the embed axis be divisible by the number of devices (really, by the size of the data axis).

Usually this is fine, but recently @ahmeda14960 wanted to use 5 GPUs and couldn't. I think PyTorch's FSDP is willing to handle this?

At a minimum, we should catch the JAX exception and explain the error.
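A minimal sketch of what that check might look like, assuming the data axis spans all devices; `shard_with_check` is a hypothetical helper for illustration, not an existing Levanter or JAX function:

```python
import jax


def shard_with_check(x, sharding, axis_name: str, axis_size: int):
    # Assumption: the data axis spans every device.
    data_axis_size = jax.device_count()
    if axis_size % data_axis_size != 0:
        # Fail before JAX does, with an actionable message instead of a
        # generic sharding ValueError.
        raise ValueError(
            f"Cannot FSDP-shard axis {axis_name!r}: its size ({axis_size}) is not "
            f"divisible by the data axis size ({data_axis_size}). Pad the axis to "
            f"a multiple of {data_axis_size} or change the number of devices."
        )
    return jax.device_put(x, sharding)
```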
I talked to the JAX folks, and JAX is willing to do this sharding inside a jit boundary, just not outside. So we could pad on the outside and immediately slice inside the boundary. Apparently Pax does it this way?
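A rough sketch of the pad-then-slice idea on a 1-D parameter, with made-up sizes and a mesh whose single `data` axis covers all devices:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

embed = 130                                # deliberately not divisible by most device counts
n_dev = jax.device_count()
padded_embed = -(-embed // n_dev) * n_dev  # round up to a multiple of n_dev

mesh = Mesh(np.array(jax.devices()), ("data",))
shard = NamedSharding(mesh, P("data"))     # shard the (padded) embed axis

# Outside jit: pad so the sharded axis divides evenly, then place on devices.
params = jnp.pad(jnp.zeros((embed,)), (0, padded_embed - embed))
params = jax.device_put(params, shard)

@jax.jit
def forward(p):
    # Inside jit: immediately slice back to the logical size. GSPMD is free
    # to repartition intermediates however it likes within the jit boundary.
    return p[:embed].sum()

forward(params)
```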
Had another request, so bumping priority.
@blahBlahhhJ's #588 should solve this. We might need to improve the docs, though.