facebookresearch / jepa

PyTorch code and models for V-JEPA self-supervised learning from video.

FSDP Support #51

Open andrew-bydlon opened 2 months ago

andrew-bydlon commented 2 months ago

This is a bit of a technical challenge and/or question. Both I-JEPA and V-JEPA use DDP rather than FSDP. This puts an inherent cap on the size of the models that can be trained: the memory of a single GPU.

I'm wondering if any thought is being put into supporting JEPA with FSDP. In my mind, the flow would be to (see the sketch after this list):

  1. Ensure that the model sharding of the target and context encoder is equivalent.
  2. Update only the sharded parameters on a particular node (this could even be a performance improvement over DDP).
  3. During forward passes, share the locally updated weights with all nodes.
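
Roughly, something like this (just a sketch of steps 1-3; names like `build_encoder` and `block_cls` are placeholders for this repo's encoder constructor and transformer block class, not actual code from the repo):

```python
import copy

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy


def wrap_encoders(build_encoder, block_cls, device):
    context_encoder = build_encoder().to(device)
    # The target encoder starts as an exact copy of the context encoder.
    target_encoder = copy.deepcopy(context_encoder)

    # Identical wrap policy + identical module structure should give an
    # identical flat-parameter layout, so shard i of the target encoder
    # corresponds to shard i of the context encoder on every rank (step 1).
    policy = ModuleWrapPolicy({block_cls})
    context_encoder = FSDP(context_encoder, auto_wrap_policy=policy, use_orig_params=True)
    target_encoder = FSDP(target_encoder, auto_wrap_policy=policy, use_orig_params=True)
    return context_encoder, target_encoder


@torch.no_grad()
def ema_update(target_encoder, context_encoder, momentum=0.998):
    # Each rank updates only its local shards (step 2); FSDP all-gathers the
    # updated shards again on the next forward pass (step 3).
    for p_t, p_c in zip(target_encoder.parameters(), context_encoder.parameters()):
        p_t.mul_(momentum).add_(p_c.detach(), alpha=1.0 - momentum)
```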

I attempted to implement something like this on my side, but FSDP seems to shard the parameters somewhat arbitrarily, i.e. not satisfying 1. above.

Any suggestions?

russellhowes commented 2 months ago

Hi Andrew, thanks for reaching out. FSDP support is on our task list, but we haven't implemented it yet.

I'll keep this open, and check in with any updates on our side (or keep an eye out if you end up getting something working 🙂)

andrew-bydlon commented 2 months ago

Thanks, Russell. Looking forward to hearing more!

I've tried to implement it, but my loss starts creeping up after a few thousand steps. My hypothesis is that FSDP's per-module wrapping yields differently flattened parameters for the two encoders on each GPU, but I'm not totally sure. I keep the predictor + context encoder in one module but include the encoder in my auto-wrap policy (roughly the setup sketched below), which is slightly different from this repo, so maybe it would work better here.
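
To make that last point concrete, this is roughly the wrapping I mean (the container module and names here are placeholders, not code from this repo):

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class ContextAndPredictor(nn.Module):
    # Placeholder container: context encoder and predictor in one module.
    def __init__(self, context_encoder: nn.Module, predictor: nn.Module):
        super().__init__()
        self.context_encoder = context_encoder
        self.predictor = predictor


def wrap(module: ContextAndPredictor) -> FSDP:
    # The encoder class is listed in the auto-wrap policy, so the encoder
    # becomes its own FSDP unit; the predictor stays in the root wrapper.
    policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={type(module.context_encoder)},
    )
    return FSDP(module, auto_wrap_policy=policy, use_orig_params=True)
```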

Hope the info helps. Will update if something magical happens 😅