facebookresearch / dino

PyTorch code for Vision Transformers training with the Self-Supervised learning method DINO
Apache License 2.0
6.26k stars 906 forks

model collapse after a few steps #43

Closed Doom9234 closed 3 years ago

Doom9234 commented 3 years ago

I used custom data to train DINO, and the model seems to collapse after a few steps: the features become uniform. I used a larger teacher temperature to enhance "sharpening", but the model still collapsed. I wonder if DINO is sensitive to the data; in other words, does DINO tend to collapse when trained on different data?

mathildecaron31 commented 3 years ago

Hi @Doom9234

To enhance sharpening, you should use a lower temperature (for example 0.02 instead of the default 0.04) for the teacher.
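
For intuition, here is a minimal sketch (illustrative only, not the exact DINOLoss from this repo) of how a lower teacher temperature sharpens the target distribution:

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, 0.1])   # example teacher outputs after centering

for temp in (0.07, 0.04, 0.02):
    probs = F.softmax(logits / temp, dim=-1)
    # the peak probability grows as the temperature shrinks, i.e. the target gets sharper
    print(temp, probs.max().item())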

mathildecaron31 commented 3 years ago

I did experiment with training DINO on the Google Landmarks v2 dataset (clean set), and also with random public non-EU Instagram images. These are large datasets (at least the size of ImageNet) and DINO has been working out of the box with the default parameters specified in this repo.

However, I have not experimented with smaller datasets and so I would be very curious to have your feedback on that.

Doom9234 commented 3 years ago

I did experiment with training DINO on the Google Landmarks v2 dataset (clean set), and also with random public non-EU Instagram images. These are large datasets (at least the size of ImageNet) and DINO has been working out of the box with the default parameters specified in this repo.

However, I have not experimented with smaller datasets and so I would be very curious to have your feedback on that.

Hi @mathildecaron31 , I used a non-public, manually-labeled dataset with 44M images and 100K classes. I trained DINO using the default parameters specified in this repo. With 48 GPUs the model collapsed quite soon (around 40k steps). Interestingly, the model did not collapse as quickly when I used 8 GPUs, and I don't know yet whether it will collapse after more epochs (the 8-GPU model is currently training).

However, I did not strictly follow the code in this repo; for example, I used NVIDIA/DALI to accelerate training and torch.multiprocessing to distribute the model across GPUs. So I am still working on the collapse problem, trying to find out whether my custom implementation or the dataset leads to the model collapse.

mathildecaron31 commented 3 years ago

Hi @Doom9234

Sorry for the confusion. I don't know why I somehow imagined you were using a small dataset, but it seems to be the opposite!

I see that you're training with 48 GPUs; what is the effective batch size then? I have not been able to train models to good performance with an effective batch size larger than 1024. I think large-batch training would require adapting the optimization to stabilize training.
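
For reference, a rough sketch of how the effective batch size relates to the linear learning-rate scaling used in main_dino.py (lr * total batch size / 256; the per-GPU batch size below is just an example value):

batch_size_per_gpu = 64                                  # example value, not necessarily yours
num_gpus = 48
base_lr = 0.0005                                         # default --lr in this repo

effective_batch_size = batch_size_per_gpu * num_gpus     # 3072 with these numbers
scaled_lr = base_lr * effective_batch_size / 256.0       # 0.006
print(effective_batch_size, scaled_lr)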

Btw, I think DALI is a great library; I used it quite a lot myself at some point :).

mathildecaron31 commented 3 years ago

Hi @Doom9234, have you been able to fix the issue?

Doom9234 commented 3 years ago

Hi @mathildecaron31 , sorry for the late reply. After some experiments, I found that batch size is the key hyperparameter for solving the collapse problem. I downsampled the 44M dataset to 2.5M (clean version) and used the default parameters specified in this repo, including the batch size, and the model finally converged. The first time, I trained DINO with a batch size of 3076, which led to the collapse. After solving the collapse problem, I trained two versions of DINO for content-based image retrieval (CBIR) tasks: VIT_S_Patch16_300epochs_16GPUS and VIT_S_Patch8_800epochs_64GPUS. Here are some conclusions based on my own test set:

1. The AUC and mAP@10 of DINO surpass all the other CNN-based SSL models, such as MoCo and BYOL, by a considerable gap.
2. ViT-S with patch size 8 is better than patch size 16.
3. I am now using DINO as the pretrained model in some fine-tuning experiments to verify its transfer performance.

Thanks for the great work; DINO shows great potential in CBIR tasks. Btw, I want to use a larger batch size to accelerate training. Can you give me some suggestions for stabilizing the model with a larger batch size?

woctezuma commented 3 years ago

ViT-S with patch size 8 is better than patch size 16.

As expected. The ViT-S/16 offers the worst performance of the ViT-*/* models. The only big question is where ViT-S/8 lies compared to the ViT-B/* architectures, and this seems to be task-dependent and effort-dependent (as it is possible to do more experiments with a small model, and thus easier to push the performance). cf. #13

mathildecaron31 commented 3 years ago

Hi @Doom9234, thanks for the update. I am happy to see that you've managed to get good results on your use case!

Stabilizing training with large batches is an open problem, so I don't have a good answer for you :/. Maybe you can check the LAMB optimizer: https://arxiv.org/abs/1904.00962
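
In case it helps, a hedged sketch of trying LAMB via the third-party torch-optimizer package (pip install torch-optimizer); the model here is just a stand-in, not the actual DINO student:

import torch
import torch_optimizer as optim

model = torch.nn.Linear(384, 65536)                        # stand-in for the DINO student
optimizer = optim.Lamb(model.parameters(), lr=0.002, weight_decay=0.04)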

swarajnanda2021 commented 6 months ago

Sadly, I have been struggling with this issue myself. I have a custom implementation of DINO, and I've checked nearly every line possible to see if there is any difference. I get a collapse either immediately or later, depending on the learning rate.

Strangely, it always collapses at the same value: 8.317765235900879
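
For what it's worth, that plateau is ln(4096), i.e. the cross-entropy against a uniform softmax over the 4096-dimensional head output configured below, which is consistent with the student collapsing to a uniform distribution. A quick check:

import math

head_dim = 4096                 # dinoheadout in the hyperparameters below
print(math.log(head_dim))       # ~8.31777, matching the plateau value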

Here is what this looks like: (screenshot attached)

Secondly, here is what my code looks like:

The loss function is a direct copy-and-paste from DINO's repository, while the method is written as a PyTorch Lightning module, which is itself a direct copy-and-paste from another repository.

# Assumed imports; DiNOProjection, MultiCropWrapper, DiNOLoss and LAMB are provided
# elsewhere in the author's code and are not shown here.
import numpy as np
import torch
import pytorch_lightning as pl


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
                     warmup_epochs=0, start_warmup_value=1e-6):
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * \
                             (1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule

class DiNO(pl.LightningModule):
    def __init__(
            self,
            feature_size, 
            dinoheadout,
            encoder, 
            temperature_teacher, 
            final_teacher_temp,
            ncrops, 
            momentum_teacher,
            batch_size,
            num_gpus,
            lr,
            initial_weight_decay,
            final_weight_decay,
            warmup_epochs,
            max_epochs,
            iters_per_epoch,
    ):
        super().__init__()
        self.save_hyperparameters(ignore=['encoder'])
        self.automatic_optimization = False
        student_backbone = encoder
        self.teacher_backbone = encoder
        # set teacher_backbone dropouts to zero
        self.teacher_backbone.attention_dropout, self.teacher_backbone.projection_dropout = 0.0, 0.0

        student_head = DiNOProjection(
                                    in_dim=feature_size,
                                    out_dim=dinoheadout,
                                    use_bn=False,
                                    norm_last_layer=False
                                    )
        teacher_head = DiNOProjection(
                                    in_dim=feature_size,
                                    out_dim=dinoheadout,
                                    use_bn=False,
                                    norm_last_layer=True
                                    )

        self.student = MultiCropWrapper(student_backbone, student_head)
        self.teacher = MultiCropWrapper(self.teacher_backbone, teacher_head)
        # teacher and student start with the same weights
        self.teacher.load_state_dict(self.student.state_dict())

        for p in self.teacher.parameters(): p.requires_grad = False

        self.loss = DiNOLoss(
                        out_dim     = dinoheadout,
                        ncrops      = ncrops + 2,
                        warmup_teacher_temp = temperature_teacher,
                        final_teacher_temp  = final_teacher_temp,
                        warmup_teacher_temp_epochs = 30,
                        nepochs     = max_epochs,
                        )

    def configure_optimizers(self):
        # First, ensure that weight decay does not apply to bias terms or norm params
        regularized, not_regularized = [], []
        for n, p in self.student.named_parameters():
            if not p.requires_grad:
                continue
            # we do not regularize biases nor Norm parameters
            if n.endswith(".bias") or len(p.shape) == 1:
                not_regularized.append(p)
            else:
                regularized.append(p)
        param_groups = [{'params': regularized                          },
                        {'params': not_regularized, 'weight_decay': 0.  }]

        lr = self.hparams.lr * (self.hparams.batch_size*self.hparams.num_gpus/256) # add accum gradient here later
        optimizer = LAMB(param_groups, lr=lr)

        return optimizer

    def setup(self, stage = None):

        # define schedulers based on number of iterations
        niter_per_ep = self.hparams.iters_per_epoch
        self.lr_sch = cosine_scheduler(self.hparams.lr, 
                                       1e-6, 
                                       self.hparams.max_epochs, 
                                       niter_per_ep//self.hparams.num_gpus,
                                       self.hparams.warmup_epochs)
        # weight decay scheduler
        self.wd_sch = cosine_scheduler(self.hparams.initial_weight_decay, 
                                       self.hparams.final_weight_decay,
                                       self.hparams.max_epochs, 
                                       niter_per_ep//self.hparams.num_gpus)
        # momentum scheduler
        self.mm_sch = cosine_scheduler(self.hparams.momentum_teacher, 
                                       1.0,
                                       self.hparams.max_epochs, 
                                       niter_per_ep//self.hparams.num_gpus)

    def training_step(self, batch, batch_idx):
        """
        batch: a list of "2+local_crops_number" tensors
               each tensor is of shape (B, 3, h, w)
        """

        imgs, _ = batch

        opt = self.optimizers()
        # update learning rate, weight decay
        for i, param_group in enumerate(opt.param_groups):
            param_group['lr'] = self.lr_sch[self.global_step]
            if i == 0: # only the first group is regularized
                param_group['weight_decay'] = self.wd_sch[self.global_step]

        teacher_output = self.teacher(imgs[:2])
        student_output = self.student(imgs)
        loss = self.loss(student_output, teacher_output, self.current_epoch)

        opt.zero_grad()
        self.manual_backward(loss)

        # Perform gradient clipping and freeze last layer of student if first epoch
        torch.nn.utils.clip_grad_norm_(self.student.parameters(), 0.5)
        if self.current_epoch < 1:
            for n, p in self.student.named_parameters():
                if 'last_layer' in n:
                    p.grad = None
        opt.step()

        # EMA update for the teacher
        m = self.mm_sch[self.global_step]
        for ps, pt in zip(self.student.parameters(), self.teacher.parameters()):
            pt.data.mul_(m).add_((1-m)*ps.detach().data)

        # Logging
        self.log('loss', loss, prog_bar=True)
        self.log('lr',self.trainer.optimizers[0].param_groups[0]['lr'],prog_bar=True)

The hyperparameters I have used are as follows:

batch_size: 24
dinoheadout: 4096
feature_size: 384
final_teacher_temp: 0.07
final_weight_decay: 0.4
initial_weight_decay: 0.04
iters_per_epoch: 55416
lr: 0.0005
max_epochs: 300
momentum_teacher: 0.996
ncrops: 4
num_gpus: 4
temperature_teacher: 0.001 (I have been reducing this from 0.02 to 0.01 to 0.005 and now 0.001, with no help)
warmup_epochs: 10

oggyfaker commented 2 months ago

I think your code is quite good as a Lightning implementation. I also tried it in Lightning and got my best result with a loss of around 0.3-0.5, and it goes lower if you keep training :D. I think you need to check the following:

swarajnanda2021 commented 2 months ago

I found the issue. I should have been using copy.deepcopy() to create a separate copy of the encoder when assigning the backbone to the student and the teacher.

The following lines were the culprit:

student_backbone = encoder
self.teacher_backbone = encoder

It should instead have been:

student_backbone = encoder
self.teacher_backbone = copy.deepcopy(encoder)

Otherwise both names reference the same module in memory, so the student and teacher share the same weights and their representations become identical after a few iterations.
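
A minimal, self-contained illustration of the aliasing (nn.Linear is just a stand-in for the encoder):

import copy
import torch.nn as nn

encoder = nn.Linear(4, 4)

student_backbone = encoder
teacher_backbone = encoder                        # aliased: both names refer to the same module object
assert student_backbone is teacher_backbone       # so the EMA "teacher" would simply mirror the student

teacher_backbone = copy.deepcopy(encoder)         # independent copy with its own parameters
assert teacher_backbone is not encoder
assert all(p_t is not p_s for p_t, p_s in
           zip(teacher_backbone.parameters(), encoder.parameters()))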