Closed Doom9234 closed 3 years ago
Hi @Doom9234
To enhance sharpening, you should use a lower temperature (for example 0.02 instead of the default 0.04) for the teacher.
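To illustrate why a lower temperature sharpens the teacher output, here is a minimal sketch of a temperature-scaled softmax. The logits are made-up values chosen only for illustration, and `softmax_with_temperature` is a hypothetical helper, not a function from this repo.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Softmax over logits / temperature (numerically stabilized)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max before exponentiating to avoid overflow
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up logits, for illustration only.
logits = [0.10, 0.08, 0.05]

p_default = softmax_with_temperature(logits, 0.04)  # default teacher temperature
p_sharper = softmax_with_temperature(logits, 0.02)  # lower temperature

# The lower temperature puts more probability mass on the largest logit.
print(max(p_default), max(p_sharper))
```

Dividing by a smaller temperature stretches the gaps between logits, so the resulting distribution is more peaked.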
I did experiment with training DINO on the Google Landmark v2 dataset (clean set), and also on random public non-EU Instagram images. These are large datasets (at least the size of ImageNet), and DINO worked out of the box with the default parameters specified in this repo.
However, I have not experimented with smaller datasets and so I would be very curious to have your feedback on that.
Hi @mathildecaron31, I used a non-public, manually labeled dataset with 44M images and 100K classes. I trained DINO using the default parameters specified in this repo. With 48 GPUs, the model collapsed quickly (around 40k steps). Interestingly, the model did not collapse as quickly when I used 8 GPUs, and I don't know whether it will collapse after more epochs (the 8-GPU model is still training).
However, I did not strictly follow the code in this repo; for example, I used NVIDIA/DALI to speed up training and torch.multiprocessing to distribute the model across GPUs. So I am still working on the "collapse" problem, to find out whether my custom implementation or the dataset causes the model to collapse.
Hi @Doom9234
Sorry for the confusion, I don't know why I somehow imagined you were using a small dataset but it seems to be the opposite!
I see that you're training with 48 GPUs; what is the effective batch size then? I have not been able to train models to good performance with an effective batch size larger than 1024. I think large-batch training would require adapting the optimization to stabilize the training.
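For reference, the effective batch size is the per-GPU batch size multiplied by the number of GPUs. A minimal sketch of that arithmetic follows, together with the linear learning-rate scaling (base lr times effective batch over 256) that appears in the Lightning code posted later in this thread. The helper names are hypothetical, and the numbers are the hyperparameters from that later post, assuming its `batch_size` is per GPU.

```python
def effective_batch_size(per_gpu_batch, num_gpus):
    """Total number of samples contributing to one optimizer step."""
    return per_gpu_batch * num_gpus

def linearly_scaled_lr(base_lr, effective_batch, reference_batch=256):
    """Linear scaling: lr grows in proportion to the effective batch size."""
    return base_lr * effective_batch / reference_batch

# Values from the hyperparameter list later in this thread:
# batch_size: 24, num_gpus: 4, lr: 0.0005
eff = effective_batch_size(24, 4)
lr = linearly_scaled_lr(0.0005, eff)
print(eff, lr)
```

Under this rule the learning rate grows with the number of GPUs, which is one reason the same nominal settings can behave differently on 8 and 48 GPUs.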
Btw, I think DALI is a great library, I've been using it quite a lot at some point myself :).
Hi @Doom9234, have you been able to fix the issue?
Hi @mathildecaron31, sorry for the late reply. After some experiments, I found that batch size is the key hyperparameter for solving the "collapse" problem. I downsampled the 44M dataset to 2.5M (clean version) and used the default parameters specified in this repo, including batch size, and the model finally converged. The first time, I trained DINO with a batch size of 3076, which led to the model collapse. After solving the collapse problem, I trained two versions of DINO for content-based image retrieval (CBIR) tasks: VIT_S_Patch16_300epochs_16GPUS and VIT_S_Patch8_800epochs_64GPUS. Here are some conclusions based on my own test set:
1. The AUC and mAP@10 of DINO surpass all the other CNN-based SSL models, such as MoCo and BYOL, by a considerable gap.
2. VIT_S with patch size 8 is better than patch size 16.
3. I am now using DINO as the pretrained model for fine-tuning experiments to verify its transfer performance.
Thanks for the great work; DINO shows great potential in CBIR tasks. Btw, I want to use a larger batch size to accelerate the training process. Can you give me some suggestions for stabilizing the model with a larger batch size?
VIT_S with patch size 8 is better than patch size 16.
As expected. The ViT-S/16 offers the worst performance of the ViT-*/* models. The only big question is where ViT-S/8 lies compared to the ViT-B/* architectures, and this seems to be task-dependent and effort-dependent (as it is possible to do more experiments with a small model, and thus easier to push the performance). cf. #13
Hi @Doom9234, thanks for the update. I am happy to see that you've managed to get good results on your use case!
Stabilizing training with large batches is an open problem so I don't have a good answer for you :/. Maybe you can check the LAMB optimizer https://arxiv.org/abs/1904.00962
Sadly, I have been struggling to solve this issue myself. I've got a custom implementation of DINO, and I've checked nearly every line to see if there is any difference. I get a collapse immediately, or later, depending on the learning rate.
Strangely, it always collapses at the same number: 8.317765235900879
Here is what this looks like:
Secondly, here is what my code looks like:
The loss function is a direct copy and paste from the DINO repository, while the method is written as a PyTorch Lightning module, which is itself a direct copy and paste from another repository.
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
                     warmup_epochs=0, start_warmup_value=1e-6):
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * \
        (1 + np.cos(np.pi * iters / len(iters)))
    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule
class DiNO(pl.LightningModule):
    def __init__(
        self,
        feature_size,
        dinoheadout,
        encoder,
        temperature_teacher,
        final_teacher_temp,
        ncrops,
        momentum_teacher,
        batch_size,
        num_gpus,
        lr,
        initial_weight_decay,
        final_weight_decay,
        warmup_epochs,
        max_epochs,
        iters_per_epoch,
    ):
        super().__init__()
        self.save_hyperparameters(ignore=['encoder'])
        self.automatic_optimization = False
        student_backbone = encoder
        self.teacher_backbone = encoder
        # set teacher_backbone dropouts to zero
        self.teacher_backbone.attention_dropout, self.teacher_backbone.projection_dropout = 0.0, 0.0
        student_head = DiNOProjection(
            in_dim=feature_size,
            out_dim=dinoheadout,
            use_bn=False,
            norm_last_layer=False
        )
        teacher_head = DiNOProjection(
            in_dim=feature_size,
            out_dim=dinoheadout,
            use_bn=False,
            norm_last_layer=True
        )
        self.student = MultiCropWrapper(student_backbone, student_head)
        self.teacher = MultiCropWrapper(self.teacher_backbone, teacher_head)
        # teacher and student start with the same weights
        self.teacher.load_state_dict(self.student.state_dict())
        for p in self.teacher.parameters():
            p.requires_grad = False
        self.loss = DiNOLoss(
            out_dim=dinoheadout,
            ncrops=ncrops + 2,
            warmup_teacher_temp=temperature_teacher,
            final_teacher_temp=final_teacher_temp,
            warmup_teacher_temp_epochs=30,
            nepochs=max_epochs,
        )

    def configure_optimizers(self):
        # First, ensure that weight decay does not apply to bias terms or norm params
        regularized, not_regularized = [], []
        for n, p in self.student.named_parameters():
            if not p.requires_grad:
                continue
            # we do not regularize biases nor Norm parameters
            if n.endswith(".bias") or len(p.shape) == 1:
                not_regularized.append(p)
            else:
                regularized.append(p)
        param_groups = [{'params': regularized},
                        {'params': not_regularized, 'weight_decay': 0.}]
        lr = self.hparams.lr * (self.hparams.batch_size*self.hparams.num_gpus/256)  # add accum gradient here later
        optimizer = LAMB(param_groups, lr=lr)
        return optimizer

    def setup(self, stage=None):
        # define schedulers based on number of iterations
        niter_per_ep = self.hparams.iters_per_epoch
        self.lr_sch = cosine_scheduler(self.hparams.lr,
                                       1e-6,
                                       self.hparams.max_epochs,
                                       niter_per_ep//self.hparams.num_gpus,
                                       self.hparams.warmup_epochs)
        # weight decay scheduler
        self.wd_sch = cosine_scheduler(self.hparams.initial_weight_decay,
                                       self.hparams.final_weight_decay,
                                       self.hparams.max_epochs,
                                       niter_per_ep//self.hparams.num_gpus)
        # momentum scheduler
        self.mm_sch = cosine_scheduler(self.hparams.momentum_teacher,
                                       1.0,
                                       self.hparams.max_epochs,
                                       niter_per_ep//self.hparams.num_gpus)

    def training_step(self, batch, batch_idx):
        """
        batch: a list of "2+local_crops_number" tensors
        each tensor is of shape (B, 3, h, w)
        """
        imgs, _ = batch
        opt = self.optimizers()
        # update learning rate, weight decay
        for i, param_group in enumerate(opt.param_groups):
            param_group['lr'] = self.lr_sch[self.global_step]
            if i == 0:  # only the first group is regularized
                param_group['weight_decay'] = self.wd_sch[self.global_step]
        teacher_output = self.teacher(imgs[:2])
        student_output = self.student(imgs)
        loss = self.loss(student_output, teacher_output, self.current_epoch)
        opt.zero_grad()
        self.manual_backward(loss)
        # Perform gradient clipping and freeze last layer of student if first epoch
        torch.nn.utils.clip_grad_norm_(self.student.parameters(), 0.5)
        if self.current_epoch < 1:
            for n, p in self.student.named_parameters():
                if 'last_layer' in n:
                    p.grad = None
        opt.step()
        # EMA update for the teacher
        m = self.mm_sch[self.global_step]
        for ps, pt in zip(self.student.parameters(), self.teacher.parameters()):
            pt.data.mul_(m).add_((1-m)*ps.detach().data)
        # Logging
        self.log('loss', loss, prog_bar=True)
        self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], prog_bar=True)
The hyperparameters I have used are as follows:
batch_size: 24
dinoheadout: 4096
feature_size: 384
final_teacher_temp: 0.07
final_weight_decay: 0.4
initial_weight_decay: 0.04
iters_per_epoch: 55416
lr: 0.0005
max_epochs: 300
momentum_teacher: 0.996
ncrops: 4
num_gpus: 4
temperature_teacher: 0.001 (Have been trying to reduce this from 0.02 to 0.01 to 0.005 and now 0.001. No help)
warmup_epochs: 10
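One observation on the plateau value: 8.3177 matches ln(4096) to float32 precision, and 4096 is the `dinoheadout` listed above. The cross-entropy between two uniform distributions over K outputs equals ln K, so this is the value one would expect if the student and teacher outputs had both become uniform. That reading is an interpretation of the number, not something confirmed in the thread. A minimal sketch of the arithmetic:

```python
import math

K = 4096  # dinoheadout from the hyperparameters above

# Uniform teacher target p and uniform student prediction q over K outputs.
uniform = [1.0 / K] * K

# Cross-entropy H(p, q) = -sum_i p_i * log(q_i); for uniform p and q this is log(K).
cross_entropy = -sum(p * math.log(q) for p, q in zip(uniform, uniform))

print(cross_entropy, math.log(K))
```

A loss that settles at exactly ln(dinoheadout), whatever the learning rate, therefore points to uniform outputs more than to an optimization detail.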
I think your code is quite good for a Lightning implementation. I also tried Lightning and got my best result with a loss of 0.3 to 0.5, and it would go lower with further training :D. I think you need to check the following:
I found the issue. I should have used copy.deepcopy() to create a separate copy of the encoder when assigning the backbone to the student and the teacher.
The following lines were the culprit:
student_backbone = encoder
self.teacher_backbone = encoder
It should instead have been:
student_backbone = encoder
self.teacher_backbone = copy.deepcopy(encoder)
Otherwise, both variables refer to the same object in memory, and this means the representations become identical after a few iterations.
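The aliasing described here can be reproduced without PyTorch. The sketch below uses a hypothetical `TinyEncoder` class as a stand-in for the real backbone; it only demonstrates Python's reference semantics and `copy.deepcopy`, not the DINO training loop.

```python
import copy

class TinyEncoder:
    """Hypothetical stand-in for a backbone; holds a single weight list."""
    def __init__(self):
        self.weights = [0.0, 0.0]

# Aliased: plain assignment makes both names point at the same object.
encoder = TinyEncoder()
student_shared = encoder
teacher_shared = encoder
student_shared.weights[0] = 1.0  # simulate a student update
shared_teacher_moved = teacher_shared.weights[0] == 1.0  # teacher changed too

# Independent: deepcopy gives the teacher its own copy of the weights.
encoder2 = TinyEncoder()
student = encoder2
teacher = copy.deepcopy(encoder2)
student.weights[0] = 1.0  # simulate a student update
independent_teacher_moved = teacher.weights[0] == 1.0  # teacher is unaffected

print(shared_teacher_moved, independent_teacher_moved)
```

With a shared object, every student update is also a teacher update, so the EMA step has nothing to average between.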
I use custom data to train DINO, and the model seems to collapse after a few steps; the features seem to be uniform. I used a larger teacher temperature to enhance "sharpening", but the model collapsed anyway. I wonder whether DINO is sensitive to the data; in other words, does DINO tend to collapse when trained on different data?