cmartin-isla opened this issue 3 years ago
I've had the exact same oscillation issue, even without cosine scheduling for the learning rate, and even within a single epoch. I have no idea why it happens either.
@cmartin-isla probably much too late for your uses, but I've run into a similar issue.
Currently having some success with lowered learning rates, although that's not a full solution.
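Roughly what I mean by a lowered learning rate, using the standard training loop from the README (`learner`, `dataloader`, and `device` are placeholders for my own setup, and the exact value 3e-5 is illustrative, not a recommendation):

```python
import torch

# Illustrative only: `learner` is the Dino wrapper, `dataloader` yields image batches.
optimizer = torch.optim.Adam(learner.parameters(), lr=3e-5)  # lowered lr, placeholder value

for images in dataloader:
    images = images.to(device)
    loss = learner(images)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    learner.update_moving_average()  # update the teacher EMA, as in the vit-pytorch README
```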
Regarding extracting attention, I found this visualization enlightening:
```python
import torch
import matplotlib.pyplot as plt
from vit_pytorch.recorder import Recorder

# Assumed to already exist: learner (the Dino wrapper), dataloader, device,
# image_size, and patches_per_side = image_size // patch_size.

# I wasn't sure where the CLS token sits in the attention matrix or which axis
# is the query, so I plot all four (cls_last, query_last) combinations and keep
# the one that looks sensible.
for cls_last, query_last in zip([True, True, False, False],
                                [True, False, False, True]):
    with torch.no_grad():
        batch = next(iter(dataloader))
        batch = batch.to(device)
        recorder = Recorder(learner.net)  # wraps the ViT and records its attention
        preds, attns = recorder(batch)    # attns: (batch, layers, heads, tokens, tokens)
        if cls_last and query_last:
            reshaped_attns = attns[..., :-1, -1]
        elif cls_last and not query_last:
            reshaped_attns = attns[..., -1, :-1]
        elif not cls_last and query_last:
            reshaped_attns = attns[..., 1:, 0]
        else:
            reshaped_attns = attns[..., 0, 1:]
        # Average over layers and heads, then reshape to the patch grid.
        reshaped_attns = reshaped_attns.mean(1).mean(1)
        reshaped_attns = reshaped_attns.view(-1, patches_per_side, patches_per_side)

    fig = plt.figure(figsize=(12, 8))
    fig.suptitle(f'{cls_last=} {query_last=}')
    for i_img in range(8):  # assumes a batch size of at least 8
        plt.subplot(241 + i_img)
        plt.imshow(batch[i_img, ...].squeeze().cpu(), cmap='gist_gray')
        plt.imshow(reshaped_attns[i_img, ...].squeeze().cpu(), cmap='inferno',
                   alpha=0.2, extent=[0, image_size, image_size, 0])
        plt.colorbar(orientation='horizontal')
        plt.axis('off')
    plt.show()
```
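If you want the overlay at full image resolution instead of relying on imshow's `extent` trick, the patch-grid map can also be upsampled first; a quick sketch using PyTorch's bilinear interpolation:

```python
import torch.nn.functional as F

# reshaped_attns: (batch, patches_per_side, patches_per_side) from the loop above
upsampled = F.interpolate(reshaped_attns.unsqueeze(1),   # add a channel dimension
                          size=(image_size, image_size),
                          mode='bilinear', align_corners=False)
upsampled = upsampled.squeeze(1)                          # (batch, image_size, image_size)
```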
Did you end up finding solutions to your questions?
Does anyone have any updates on this? I'm about to start pre-training a ViT using DINO, but now I have doubts...
It might be the case that our datasets need to be much larger: https://github.com/facebookresearch/dino/issues/196
Hello, I am trying to train DINO with a base ViT from scratch and I have some doubts. First of all, I think that in the original paper the student temperature is 0.1 during the 30-epoch warmup, but I am not able to find this repo's default value of 0.9 anywhere in the original paper.
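For concreteness, this is roughly the constructor call my doubt is about (argument names are the ones from this repo's README; `model` is my ViT, and the temperatures shown are just the defaults I am asking about):

```python
from vit_pytorch import Dino

learner = Dino(
    model,                   # my ViT, defined elsewhere
    image_size = 224,
    hidden_layer = 'to_latent',
    student_temp = 0.9,      # this repo's default; should it be the paper's 0.1?
    teacher_temp = 0.04,
)
```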
After adding a cosine scheduler and doing some tweaking, I manage to get the loss down to the 1e-7 range, but it oscillates: every 3-4 epochs it reaches a maximum of around 20-30 and then decreases again to a new minimum near the 1e-7 order. I don't know if you have experienced that kind of behavior.

Secondly, after training I want to visualize the attention maps, but I am not sure how to do that. Let's say I have trained with 224x224 images and a patch size of 8. I get attention maps of shape (1, 6, 8, 785, 785), so 28*28 patches + 1 CLS token = 785 tokens.
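To make the indexing concrete, I think the extraction would look like the sketch below, assuming the CLS token sits at index 0 and the second-to-last axis of the attention tensor is the query (I am not sure about either assumption):

```python
image_size, patch_size = 224, 8
patches_per_side = image_size // patch_size           # 28
num_tokens = patches_per_side ** 2 + 1                # 784 patches + 1 CLS token = 785

# attns is the tensor described above: (1, 6, 8, 785, 785) = (batch, layers, heads, tokens, tokens).
# CLS-to-patch attention, averaged over layers and heads (assumptions as stated):
cls_attn = attns[..., 0, 1:].mean(dim=(1, 2))                      # (batch, 784)
cls_attn = cls_attn.view(-1, patches_per_side, patches_per_side)   # (batch, 28, 28)
```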
Thanks a lot, and again, thanks for this amazing repo.