gaasher / I-JEPA

Implementation of I-JEPA from "Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture"
MIT License

Average pooling while performing downstream tasks #13

Open DipanMandal opened 5 months ago

DipanMandal commented 5 months ago

Hi, thanks for such an amazing and straightforward implementation. I have one doubt regarding the implementation in the finetune_IJEPA.py file:

For the classification task, you are applying average pooling over the complete encoding of every output patch that comes from the target encoder. So the output of the encoder is [batch x 196 x 96], and after average pooling it turns into [batch x 196 x 1]. But doesn't that cause information loss for other architectures where the output embedding size is larger, e.g. 1280 (as in the official implementation)?
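For concreteness, here is a minimal sketch of the two pooling choices I'm contrasting (the tensor names and shapes are illustrative, not the repo's actual variables):

```python
import torch

# Illustrative shapes: batch of 8, 196 patch tokens, embedding dim 96
# (the small config discussed above; larger models use e.g. 1280).
encoding = torch.randn(8, 196, 96)  # stand-in for the target encoder output

# Pooling over the embedding dimension, as described above: each
# 96-dim patch embedding collapses to a single scalar.
pooled_per_patch = encoding.mean(dim=-1, keepdim=True)  # [8, 196, 1]

# Pooling over the patch (token) dimension instead, as in the paper's
# linear-probing setup: one embed_dim-wide vector per image.
pooled_per_image = encoding.mean(dim=1)  # [8, 96]
```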

I know that the paper also says they apply average pooling to the output, but I still wanted to know your perspective on this.