stanford-crfm / levanter

Legible, Scalable, Reproducible Foundation Models with Named Tensors and Jax
https://levanter.readthedocs.io/en/latest/
Apache License 2.0

FSDP with an odd/weird number of GPUs. #429

Open dlwh opened 8 months ago

dlwh commented 8 months ago

Currently our FSDP implementation uses JAX's sharding machinery, which requires that the embed axis be divisible by the number of devices (or, more precisely, the data axis size).

Usually this is fine, but recently @ahmeda14960 wanted to use 5 GPUs and couldn't. I think PyTorch is willing to do this?

At a minimum, we should catch the JAX exception and explain the error.
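A pre-flight check like the following sketch could produce that friendlier error. The function name and message are hypothetical, not from Levanter's codebase; the divisibility rule itself is the one described above.

```python
# Hypothetical pre-flight check: verify the embed axis is divisible by the
# data-parallel axis size before handing the array to JAX's sharding machinery,
# so users see an explanatory error instead of a raw JAX exception.
def check_fsdp_divisibility(embed_size: int, data_axis_size: int) -> None:
    if embed_size % data_axis_size != 0:
        # Round up to the next multiple for the suggestion in the message.
        next_multiple = -(-embed_size // data_axis_size) * data_axis_size
        raise ValueError(
            f"FSDP requires the embed axis ({embed_size}) to be divisible by "
            f"the data axis size ({data_axis_size}). Pad the embed axis to "
            f"{next_multiple} or change the number of devices."
        )

check_fsdp_divisibility(4096, 8)    # fine: 4096 % 8 == 0
# check_fsdp_divisibility(4096, 5)  # would raise with an explanatory message
```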

dlwh commented 8 months ago

I talked to the JAX folks, and JAX is willing to do this sharding inside a jit boundary, just not outside it. So we could pad outside the boundary and immediately slice back down inside it. Apparently PAX does it this way?
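The pad-then-slice idea can be sketched as follows, with NumPy standing in for JAX device arrays (the helper name is made up for illustration): pad the embed axis up to a multiple of the device count outside the jit boundary so the sharding constraint is satisfied, then slice back to the logical size inside it.

```python
import numpy as np

def pad_to_multiple(x: np.ndarray, multiple: int, axis: int = 0) -> np.ndarray:
    """Zero-pad `axis` of `x` up to the next multiple of `multiple`."""
    size = x.shape[axis]
    pad = (-size) % multiple  # 0 if already divisible
    widths = [(0, 0)] * x.ndim
    widths[axis] = (0, pad)
    return np.pad(x, widths)

num_devices = 5
embed = np.ones((4096, 512))                  # 4096 is not divisible by 5
padded = pad_to_multiple(embed, num_devices)  # shape (4100, 512): shardable 5 ways

# ...inside the jit boundary, immediately slice back to the logical size...
unpadded = padded[: embed.shape[0]]
assert unpadded.shape == embed.shape
```

The wasted compute is only on the padded rows, which is at most `num_devices - 1` rows of the embed axis.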

dlwh commented 7 months ago

Had another request, so bumping priority.

dlwh commented 4 months ago

@blahBlahhhJ's #588 should solve this. We might need to improve the docs, though.