gaasher / I-JEPA

Implementation of I-JEPA from "Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture"
MIT License

Using same context and targets for all the images in a batch #6

Open jhairgallardo opened 1 year ago

jhairgallardo commented 1 year ago

Hi! I see that you use the same context and target indices for all the images in a batch. I guess doing that gives you a tensor of the same size for each image when passing through the predictor (so the batch can be stacked). I saw that here (a different implementation) they used torch.vmap to apply different context and target indices to each image, but it is not clear how I could use that in your code. They claim that the loss increases exponentially if you use the same indices for every image in a batch. I have also seen this in some of my experiments, though the network still performed well on linear evaluation at the end. Do you have any ideas on how to make the context and target indices different for each image in a batch in your code?
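For reference, this is roughly what a per-image gather could look like. This is only a minimal sketch, not the repository's code or the vmap implementation linked above; the shapes, `num_ctx`, and the random index sampling are placeholder assumptions.

```python
import torch

B, N, D = 4, 196, 192          # batch size, number of patches, embedding dim (assumed)
num_ctx = 98                   # context patches per image, kept fixed so the batch stacks

x = torch.randn(B, N, D)       # patch embeddings for the batch
ctx_idx = torch.stack([        # a different index set for every image
    torch.randperm(N)[:num_ctx] for _ in range(B)
])                             # (B, num_ctx)

# Per-image gather: expand the indices over the embedding dim and gather along
# the patch axis, so each image is indexed with its own context set.
ctx = torch.gather(x, 1, ctx_idx.unsqueeze(-1).expand(-1, -1, D))   # (B, num_ctx, D)

# The same selection written with torch.vmap over the batch dimension.
ctx_vmap = torch.vmap(lambda xi, idx: torch.index_select(xi, 0, idx))(x, ctx_idx)
assert torch.allclose(ctx, ctx_vmap)
```

The key constraint either way is that the number of context/target patches stays the same across images, otherwise the per-image results cannot be stacked back into a single tensor.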

gaasher commented 1 year ago

I feel like the easiest way to get image-level context/target indices would actually be to compute the context/target blocks in the dataset, since the dataset already fetches one image at a time. You could pre-compute what lines 87-115 and 118-135 do inside the dataset's `__getitem__`. Then you would just have to compute the target block embeddings in the same way as lines 154-156. Unfortunately, I don't have the time to do this in the near term, but I hope this comment helps; feel free to keep asking questions and I'll answer them to the best of my ability.
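A rough sketch of that idea, sampling the blocks per image inside `__getitem__`. This is not the repository's actual code: the class name, the block sampler (random index sets instead of contiguous rectangular blocks), and the size parameters are all placeholder assumptions for the logic currently done at the batch level.

```python
import torch
from torch.utils.data import Dataset

class IJEPADataset(Dataset):
    def __init__(self, images, num_patches=196, num_targets=4,
                 target_size=9, context_size=98):
        self.images = images
        self.num_patches = num_patches
        self.num_targets = num_targets
        self.target_size = target_size    # patches per target block (assumed fixed)
        self.context_size = context_size  # patches kept in the context (assumed fixed)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        img = self.images[i]

        # Per-image target blocks: random index sets of fixed size here; a real
        # sampler would draw contiguous rectangular blocks as in the paper.
        target_idx = torch.stack([
            torch.randperm(self.num_patches)[: self.target_size]
            for _ in range(self.num_targets)
        ])                                           # (num_targets, target_size)

        # Context = patches not used by any target, truncated to a fixed length
        # so the default collate can stack samples into a batch.
        used = set(torch.unique(target_idx.flatten()).tolist())
        remaining = torch.tensor([p for p in range(self.num_patches) if p not in used])
        context_idx = remaining[torch.randperm(len(remaining))][: self.context_size]

        return img, context_idx, target_idx
```

In the training loop you would then gather the context and target embeddings per image (e.g. with `torch.gather` along the patch axis) instead of slicing the whole batch with one shared index set.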