
Malte lig 4894 fix gatherlayer #1531

Closed · MalteEbner closed this 2 months ago

MalteEbner commented 2 months ago

closes #1528

Description

Fixes the GatherLayer by using the implementation from solo-learn.
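For context, the solo-learn style GatherLayer is an autograd function that all-gathers the activations in the forward pass and all-reduces the gradients in the backward pass, so each process receives the gradient contributions from every device. Roughly (a sketch, slightly simplified from solo-learn):

```python
import torch
import torch.distributed as dist


class GatherLayer(torch.autograd.Function):
    """All-gather with a gradient path across processes."""

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        # Collect the tensor from every process; the output is a tuple
        # with one entry per device.
        output = [torch.empty_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # Sum the incoming gradients over all processes so that each
        # device sees the full gradient, then keep only the local slice.
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]
```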

Tests

Adds a test for the GatherLayer by training a model with the NTXentLoss criterion. It compares the training behaviour of the following two cases and ensures that they are exactly the same:

  1. n_devices=1, batch_size=8
  2. n_devices=2, batch_size=4

The test is in the new file test_dist__gather.py. A separate file is needed because using a DDPStrategy causes the whole file to be executed once per device. Before the fix, the test failed.
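Roughly, the test idea looks like the following sketch (the model, data, and hyperparameters here are illustrative, not the exact test code; the real test runs the whole file once per device under DDPStrategy and then compares the trained weights):

```python
import torch
import pytorch_lightning as pl
from torch import nn
from lightly.loss import NTXentLoss


class _Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)  # identical initial weights on every device
        self.layer = nn.Linear(8, 4, bias=False)
        # gather_distributed=True routes the loss through GatherLayer.
        self.criterion = NTXentLoss(gather_distributed=True)

    def training_step(self, batch, batch_idx):
        (x0, x1), _ = batch
        return self.criterion(self.layer(x0), self.layer(x1))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def train(n_devices: int, batch_size: int) -> torch.Tensor:
    # Deterministic dataset: no augmentations, no shuffling.
    xs = torch.arange(8 * 8, dtype=torch.float32).reshape(8, 8)
    dataset = [((x, x), 0) for x in xs]
    model = _Model()
    trainer = pl.Trainer(
        devices=n_devices,
        accelerator="cpu",
        strategy="ddp",
        max_epochs=1,
    )
    trainer.fit(
        model,
        torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False),
    )
    return model.layer.weight.detach()


# Run once with n_devices=1, batch_size=8 and once with n_devices=2,
# batch_size=4; after training, the weights must match exactly.
```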

Next tests

This test only covers the NTXentLoss criterion; the other models need to be tested as well.

Testing the full SimCLR model

I also tried to write a similar test for a full SimCLR model. However, it is extremely hard to get exactly the same training behaviour in that setting.

Randomness causes different behaviour between n_devices=2 and n_devices=1

Results:

Using the SimCLR transform leads to different behaviour with n_devices=2 compared to a single device. Even seeding does not help: each device processes a different number of samples, so the random number generators end up in different states and produce different augmentations.

Thus, only removing the randomness entirely makes the output of the dataloader the same for the n_devices=2 and n_devices=1 cases.

The same problem also applies to any randomness in the model itself, e.g. in dropout layers.
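One way to sidestep this in a test is to make both views fully deterministic instead of using the SimCLR transform; likewise, dropout layers can be swapped for nn.Identity. A hypothetical helper along these lines (not part of lightly):

```python
import torchvision.transforms as T

# Deterministic replacement for the SimCLR transform: no random crops,
# flips, or color jitter, so the batches seen by 1 device with
# batch_size=8 match the union of the batches seen by 2 devices with
# batch_size=4.
deterministic_transform = T.Compose([
    T.Resize((32, 32)),
    T.ToTensor(),
])


def two_views(image):
    view = deterministic_transform(image)
    return view, view.clone()
```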

Batch normalization causes different behaviour between n_devices=2 and n_devices=1

Batch normalization, or any other operation that uses information from other samples in the same batch, behaves differently with n_devices=2 & batch_size=4 than with n_devices=1 & batch_size=8. The batch normalization would need to be synchronized across devices as well for the comparison to work. As pointed out by Guarin, we could use SyncBatchNorm for this: https://lightning.ai/docs/pytorch/stable/common/trainer.html#sync-batchnorm
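In Lightning this is a single Trainer flag, which converts all BatchNorm layers to torch.nn.SyncBatchNorm under DDP:

```python
from pytorch_lightning import Trainer

trainer = Trainer(
    devices=2,
    accelerator="gpu",
    strategy="ddp",
    sync_batchnorm=True,  # BatchNorm statistics computed over the global batch
)
```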

codecov[bot] commented 2 months ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 81.96%. Comparing base (ec9f620) to head (64ff90d).

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master    #1531      +/-   ##
==========================================
+ Coverage   81.76%   81.96%   +0.20%
==========================================
  Files         144      144
  Lines        6092     6094       +2
==========================================
+ Hits         4981     4995      +14
+ Misses       1111     1099      -12
```

:umbrella: View full report in Codecov by Sentry.
:loudspeaker: Have feedback on the report? Share it here.