facebookresearch / dino

PyTorch code for Vision Transformers training with the Self-Supervised learning method DINO
Apache License 2.0

a solution to solve memory issues (but slows down training a bit) #269

Open fawazsammani opened 9 months ago

fawazsammani commented 9 months ago

Not really an issue, just a solution that requires a lot less memory (18x less). I think it would be helpful for lots of people, so I'll post it:

Multi-cropping eats a lot of GPU memory because, instead of keeping 1 computation graph, you end up keeping 18 (18 is n_loss_terms in the code below when n_local_crops = 8: 2 teacher views x 10 student crops, minus the 2 pairs where both sides are the same global view). So run every crop separately through the student and backprop with loss.backward() right away. Don't update the weights with optimizer.step() yet; accumulate the gradients over all pairs instead. Each backward() call computes the gradients for one pair and frees its computation graph before the next pair starts. After accumulating the gradients for all pairs, call optimizer.step() once. The price is extra forward passes, since every crop goes through the student once per teacher view it is paired with, which is why training slows down a bit. With this implementation I was able to use a large batch size and train on a single GPU.
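As a quick sanity check of the idea above (not part of the original post), the sketch below compares one backward pass over a summed loss with one backward pass per term on a toy model. The tiny `nn.Linear` student, the tensor sizes, and the helper `term_loss` are assumptions chosen for illustration only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-ins (hypothetical sizes): a tiny "student" and random crops/targets.
student = nn.Linear(8, 4)
crops = [torch.randn(3, 8) for _ in range(5)]
targets = [F.softmax(torch.randn(3, 4), dim=-1) for _ in range(5)]
n_terms = len(crops)


def term_loss(x, q):
    # Cross-entropy between a fixed target distribution q and the student output.
    log_p = F.log_softmax(student(x), dim=-1)
    return torch.sum(-q * log_p, dim=-1).mean() / n_terms


# (a) One backward pass over the summed loss: all graphs are alive at once.
student.zero_grad()
sum(term_loss(x, q) for x, q in zip(crops, targets)).backward()
grad_joint = student.weight.grad.clone()

# (b) One backward pass per term: each graph is freed before the next is built,
#     and the gradients accumulate in .grad.
student.zero_grad()
for x, q in zip(crops, targets):
    term_loss(x, q).backward()
grad_accumulated = student.weight.grad.clone()

# Compare the two gradients up to floating-point tolerance.
print(torch.allclose(grad_joint, grad_accumulated, atol=1e-6))
```

Both variants should give the same gradients up to floating-point error, because the gradient of a sum is the sum of the gradients; only the peak memory differs.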

import torch
import torch.nn as nn
import torch.nn.functional as F


class DINOLoss(nn.Module):
    def __init__(self, out_dim=65536, teacher_temp=0.04, student_temp=0.1, center_momentum=0.9):

        super().__init__()
        self.teacher_temp = teacher_temp
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, student, student_feats, teacher_output, epoch):
        """
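        Cross-entropy between softmax outputs of the teacher and student networks.
        student is the student network itself; student_feats is a list of image crops
        (the 2 global crops first) and len(student_feats) = n_local_crops + 2.
        backward() is called inside this method, one loss term at a time.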
        """
        teacher_out = F.softmax((teacher_output - self.center) / self.teacher_temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)
        self.update_center(teacher_output)  # not defined in this snippet; same method as in the repo's DINOLoss
        n_loss_terms = (len(teacher_out) * len(student_feats)) - len(teacher_out)
        total_loss = 0

        for iq, q in enumerate(teacher_out):
            for v, chunk in enumerate(student_feats):
                if iq == v:
                    continue
                student_output = student(chunk)   # forward computation graph
                student_output = student_output / self.student_temp
                loss = torch.sum(-q * F.log_softmax(student_output, dim=-1), dim=-1)
                loss = loss.mean() / n_loss_terms
                loss.backward()        # accumulate grads, then free this term's computation graph
                total_loss += loss.detach()    # detached copy, kept only for printing

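        return total_loss


# Usage: dino_loss is created once; the remaining lines run in every training iteration
dino_loss = DINOLoss()
teacher_feats = torch.cat(student_feats[:2]).clone().detach()
teacher_output = teacher(teacher_feats)  # only the 2 global views pass through the teacher
loss = dino_loss(student, student_feats, teacher_output, epoch)  # grads are already accumulated when this returns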
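Because the loss calls backward() itself, the surrounding training step only has to zero the gradients, call the loss, and step the optimizer once. The sketch below shows one possible ordering. The names `train_step`, `toy_loss`, and `ema_momentum` are hypothetical, `toy_loss` is a small stand-in with the same call signature as the DINOLoss above, and the EMA teacher update is an assumption based on how DINO is usually trained rather than something shown in this post.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def train_step(student, teacher, loss_fn, optimizer, crops, epoch, ema_momentum=0.996):
    """One optimisation step; loss_fn is expected to call backward() itself."""
    optimizer.zero_grad()  # clear gradients left over from the previous step
    with torch.no_grad():  # the teacher never needs a computation graph
        teacher_output = teacher(torch.cat(crops[:2]))  # 2 global views only
    loss = loss_fn(student, crops, teacher_output, epoch)  # accumulates grads
    optimizer.step()  # single weight update after all terms were accumulated
    with torch.no_grad():  # assumed EMA update of the teacher weights
        for p_s, p_t in zip(student.parameters(), teacher.parameters()):
            p_t.mul_(ema_momentum).add_(p_s.detach(), alpha=1 - ema_momentum)
    return float(loss)


def toy_loss(student, crops, teacher_output, epoch):
    # Hypothetical stand-in: per-pair forward + backward, no centering.
    targets = F.softmax(teacher_output, dim=-1).chunk(2)
    n_terms = len(targets) * len(crops) - len(targets)
    total = 0.0
    for iq, q in enumerate(targets):
        for v, crop in enumerate(crops):
            if iq == v:
                continue  # skip the pair where both sides are the same view
            log_p = F.log_softmax(student(crop), dim=-1)
            loss = torch.sum(-q * log_p, dim=-1).mean() / n_terms
            loss.backward()  # frees this pair's graph, keeps the grads
            total += loss.item()
    return total


torch.manual_seed(0)
student = nn.Linear(8, 4)
teacher = nn.Linear(8, 4)
teacher.load_state_dict(student.state_dict())
for p in teacher.parameters():
    p.requires_grad_(False)  # the teacher is never trained by backprop

optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
crops = [torch.randn(3, 8) for _ in range(4)]  # toy: 2 "global" + 2 "local"

weights_before = student.weight.detach().clone()
loss_value = train_step(student, teacher, toy_loss, optimizer, crops, epoch=0)
print(loss_value)
```

Forgetting `optimizer.zero_grad()` is the easiest mistake to make with this pattern, since gradients would then keep accumulating across iterations as well as across pairs.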

Note that in the code, student_feats are the images (they are named feats for another reason). Hope it helps :)
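The class above calls `self.update_center`, which the post does not show. A minimal single-process sketch of such an EMA center update is given here as a standalone function. The function form, the toy sizes, and the omission of any cross-GPU reduction are assumptions; in the class, the result would be assigned back to `self.center`.

```python
import torch


@torch.no_grad()
def update_center(center, teacher_output, momentum=0.9):
    """EMA of the batch mean of the teacher outputs (single process only)."""
    batch_center = teacher_output.mean(dim=0, keepdim=True)  # shape (1, out_dim)
    return center * momentum + batch_center * (1 - momentum)


# Toy check with hypothetical sizes: out_dim = 4, batch of 6 teacher outputs.
center = torch.zeros(1, 4)
teacher_output = torch.ones(6, 4)
center = update_center(center, teacher_output)
print(center)
```

With a zero center, all-ones outputs, and momentum 0.9, every entry of the new center moves one tenth of the way towards the batch mean.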