lucidrains / vit-pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch
MIT License

Training DINO #154

Open cmartin-isla opened 3 years ago

cmartin-isla commented 3 years ago

Hello, I am trying to train DINO with a base ViT from scratch and I have some questions. First of all, I think that in the original paper the student temperature is 0.1 during the 30-epoch warmup, but I cannot find this repo's default value of 0.9 anywhere in the original paper.
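
For reference, here is a minimal sketch of how the two temperatures are usually described for DINO. It assumes the settings as I read them in the paper: a fixed student temperature of 0.1, and a teacher temperature warmed up linearly from 0.04 to 0.07 over the first 30 epochs. The function names are hypothetical and this is not this repo's code; it only illustrates that a lower temperature gives a sharper softmax.

```python
import math

def teacher_temp_at(epoch, warmup_epochs=30, start=0.04, end=0.07):
    """Linear warm-up of the teacher temperature, constant afterwards.

    Defaults are assumptions taken from the paper's description.
    """
    if epoch >= warmup_epochs:
        return end
    return start + (end - start) * epoch / warmup_epochs

def softmax_with_temperature(logits, temp):
    """Softmax over logits / temp, shifted by the max for numerical stability."""
    scaled = [x / temp for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [1.0, 0.5, 0.0]
sharp = softmax_with_temperature(logits, 0.1)  # paper's student temperature
soft = softmax_with_temperature(logits, 0.9)   # default value asked about above
```

With the same logits, `sharp` puts far more mass on the largest entry than `soft` does, so a student temperature of 0.9 produces much flatter targets than 0.1.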

After adding a cosine scheduler and doing some tweaking, I managed to get the loss down to the order of 1e-7, but it oscillates: every 3-4 epochs it spikes to a maximum of around 20-30 and then decreases again to a new minimum near 1e-7. I don't know if you have experienced this kind of behavior.

Secondly, after training, I want to visualize the attention maps, but I am not sure how to do that. Let's say that I have trained with 224x224 images and a patch size of 8. My attention maps have shape (1, 6, 8, 785, 785), i.e. 28*28 patches + 1 cls token = 785 tokens.
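
The token arithmetic above can be checked with a small sketch. It uses plain Python lists so it runs without any dependencies; the helper names are hypothetical, and placing the cls token at index 0 is an assumption (it matches a ViT that prepends the cls token).

```python
def vit_token_layout(image_size, patch_size, num_cls_tokens=1):
    """Return (patches per side, number of patches, total tokens)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    side = image_size // patch_size
    num_patches = side * side
    return side, num_patches, num_patches + num_cls_tokens

def cls_row_to_grid(attn_row, side, cls_index=0):
    """Drop the cls entry from one attention row and fold it into a side x side grid."""
    patches = attn_row[:cls_index] + attn_row[cls_index + 1:]
    if len(patches) != side * side:
        raise ValueError("row length does not match the patch grid")
    return [patches[r * side:(r + 1) * side] for r in range(side)]

side, num_patches, num_tokens = vit_token_layout(224, 8)
# Dummy attention row for the cls query: one weight per token.
dummy_row = [float(i) for i in range(num_tokens)]
grid = cls_row_to_grid(dummy_row, side)
```

On a tensor of shape (1, 6, 8, 785, 785), the equivalent under the same assumption would be something like `attns[0, -1, :, 0, 1:].reshape(8, 28, 28)`, which gives one 28x28 map per head for the last layer.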

Thanks a lot, and again, thanks for this amazing repo.

adeschemps commented 2 years ago

I've had the exact same oscillation issue, even without cosine scheduling for the learning rate, and even within a single epoch. I have no idea why it happens either.

alexhagen commented 2 years ago

@cmartin-isla probably much too late for your uses, but I've run into a similar issue.

Currently having success with lowered learning rates, although that's not a full solution.
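
As a rough illustration of what lowering the learning rate can look like in combination with the cosine scheduling mentioned earlier in the thread, here is a sketch of a linear warm-up followed by cosine decay. All values (`base_lr`, `warmup_steps`, `min_lr`) are hypothetical placeholders, not settings from this repo or from anyone's actual run.

```python
import math

def lr_at(step, total_steps, base_lr=5e-4, warmup_steps=1000, min_lr=1e-6):
    """Learning rate at a given step: linear warm-up, then cosine decay to min_lr."""
    if step < warmup_steps:
        # Ramp up linearly so the first updates are small.
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))

total = 10_000
schedule = [lr_at(s, total) for s in range(total + 1)]
```

Halving `base_lr` scales the whole curve down, which is an easy first thing to try when the loss spikes periodically.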

Regarding extracting attention, I found this visualization enlightening:

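import matplotlib.pyplot as plt
import torch
from vit_pytorch.recorder import Recorder

# Assumes `learner`, `dataloader`, `device`, `image_size` and `patches_per_side`
# are already defined, and that each batch is a tensor of single-channel images.
recorder = Recorder(learner.net)  # wrap once so the hooks are registered once
batch = next(iter(dataloader)).to(device)
with torch.no_grad():
    preds, attns = recorder(batch)  # attns: (batch, layers, heads, tokens, tokens)

# Try every combination of cls position (first/last token) and attention axis,
# and plot each one so the four candidate maps can be compared.
for cls_last, query_last in zip([True, True, False, False],
                                [True, False, False, True]):
    if cls_last and query_last:
        reshaped_attns = attns[..., :-1, -1]
    elif cls_last and not query_last:
        reshaped_attns = attns[..., -1, :-1]
    elif not cls_last and query_last:
        reshaped_attns = attns[..., 1:, 0]
    else:
        # vit-pytorch's ViT prepends the cls token and applies softmax over the
        # last axis, so this case should be the cls-query-to-patch map.
        reshaped_attns = attns[..., 0, 1:]
    # Average over layers, then over heads -> (batch, num_patches)
    reshaped_attns = reshaped_attns.mean(1).mean(1)
    reshaped_attns = reshaped_attns.view(-1, patches_per_side, patches_per_side)
    fig = plt.figure(figsize=(12, 8))
    fig.suptitle(f'{cls_last=} {query_last=}')
    for i_img in range(8):  # first 8 images of the batch on a 2x4 grid
        plt.subplot(2, 4, i_img + 1)
        plt.imshow(batch[i_img, ...].squeeze().cpu(), cmap='gist_gray')
        # Overlay the attention map, stretched to the image extent
        plt.imshow(reshaped_attns[i_img, ...].squeeze().cpu(), cmap='inferno',
                   alpha=0.2, extent=[0, image_size, image_size, 0])
        plt.colorbar(orientation='horizontal')
        plt.axis('off')
    plt.show()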

Did you end up finding solutions to your questions?

RobertHua96 commented 2 years ago

Might anyone have any updates on this? I'm about to start pre-training a ViT using DINO, but now I have doubts...

RobertHua96 commented 2 years ago

It might be the case that our datasets need to be much larger: https://github.com/facebookresearch/dino/issues/196