lightly-ai / lightly

A python library for self-supervised learning on images.
https://docs.lightly.ai/self-supervised-learning/
MIT License
2.84k stars 246 forks

Add AIM Model from Scalable Pre-training of Large Autoregressive Image Models #1479

Closed guarin closed 5 months ago

guarin commented 5 months ago

This PR implements the AIM model proposed in Scalable Pre-training of Large Autoregressive Image Models. The implementation is based on the original code but uses a modified version of the timm vision transformer as the backbone. The backbone is fully compatible with the timm vision transformer, so pretrained weights from our backbone should load directly into the timm vision transformer (the state dicts are identical).

The implementation is a best effort. The paper and reference code omit some crucial information: specifically, the prefix length and a detailed description of the MLP architecture for the prediction head. Nevertheless, the current implementation runs and is hopefully a good start. I checked with the authors, and the head and prefix masking should be correct now :)
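For readers unfamiliar with prefix masking, here is a minimal sketch of a prefix-causal attention mask in plain Python. It is an illustration only, not the code from this PR: the function name `prefix_causal_mask` is hypothetical, and it assumes the usual prefix-LM form where tokens inside the prefix are visible to every query and all other tokens are visible only causally.

```python
from typing import List


def prefix_causal_mask(seq_len: int, prefix_len: int) -> List[List[bool]]:
    """Build a boolean attention mask (hypothetical helper, not from the PR).

    mask[i][j] is True when query token i may attend to key token j.
    Keys inside the prefix (j < prefix_len) are visible to every query;
    all other keys are visible only causally (j <= i).
    """
    if not 0 <= prefix_len <= seq_len:
        raise ValueError("prefix_len must be between 0 and seq_len")
    return [
        [j <= i or j < prefix_len for j in range(seq_len)]
        for i in range(seq_len)
    ]


# Print the mask for 4 tokens with a prefix of length 2.
mask = prefix_causal_mask(seq_len=4, prefix_len=2)
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
# xx..
# xx..
# xxx.
# xxxx
```

With `prefix_len=0` this reduces to an ordinary causal (lower-triangular) mask, which is a quick sanity check on the construction.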

Changes

TODO:

We also have to figure out whether we want to add this to benchmarks/imagenet/vitb16 because the backbone is clearly not vitb16 😅

How was it tested?

For Review

Review is only required for the following files/functions:

The other files/functions have already been reviewed in other PRs but are not yet on master.

codecov[bot] commented 5 months ago

Codecov Report

Attention: 108 lines in your changes are missing coverage. Please review.

Comparison is base (21bd179) 85.56% compared to head (6b93bf5) 84.40%.

| Files | Patch % | Lines |
|---|---|---|
| ...models/modules/masked_causal_vision_transformer.py | 0.00% | 58 Missing :warning: |
| lightly/models/utils.py | 38.63% | 27 Missing :warning: |
| lightly/models/modules/heads_timm.py | 0.00% | 19 Missing :warning: |
| lightly/models/modules/__init__.py | 50.00% | 2 Missing :warning: |
| lightly/__init__.py | 80.00% | 1 Missing :warning: |
| lightly/transforms/aim_transform.py | 83.33% | 1 Missing :warning: |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master    #1479      +/-   ##
==========================================
- Coverage   85.56%   84.40%   -1.16%
==========================================
  Files         136      139       +3
  Lines        5680     5777      +97
==========================================
+ Hits         4860     4876      +16
- Misses        820      901      +81
```


adamjstewart commented 5 months ago

FWIW, this PR caused a bit of a headache for us in TorchGeo: https://github.com/microsoft/torchgeo/issues/1824

At the moment, the changes here make lightly v1.4.26 incompatible with any version of segmentation-models-pytorch. This isn't necessarily your fault, but it would help if you could check the version of timm available before importing everything else.
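To make the suggestion concrete, here is a rough sketch of what such a version check could look like, using only the standard library. All names here (`MIN_TIMM_VERSION`, `parse_release`, `package_at_least`, `TIMM_AVAILABLE`) are hypothetical, and the minimum version is a placeholder assumption: this thread does not state which timm version is actually required.

```python
import re
from importlib import metadata
from typing import Tuple

# Placeholder threshold (assumption): the real minimum timm version
# is not stated in this thread.
MIN_TIMM_VERSION = (0, 9, 9)


def parse_release(version: str) -> Tuple[int, ...]:
    """Return the leading numeric part of a version string as a tuple.

    For example, "0.9.12.dev0" becomes (0, 9, 12). Returns an empty
    tuple if the string does not start with a number.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        return ()
    return tuple(int(part) for part in match.group(0).split("."))


def package_at_least(name: str, minimum: Tuple[int, ...]) -> bool:
    """True if `name` is installed and its release is >= `minimum`."""
    try:
        installed = metadata.version(name)
    except metadata.PackageNotFoundError:
        return False
    return parse_release(installed) >= minimum


# Evaluate once at import time; timm-dependent modules would only be
# imported when this flag is True.
TIMM_AVAILABLE = package_at_least("timm", MIN_TIMM_VERSION)
```

Reading the installed version via `importlib.metadata` avoids importing timm itself, so the check stays cheap and cannot fail on an incompatible timm install.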