ibro45 opened this issue 1 year ago
Hi, our VICRegL implementation assumes that the features passed to the loss have the same spatial size as the grid. So if you are using a resnet18 with 224px input images, you get 7x7xdim output features, because a ResNet downsamples its inputs by a factor of 32. This is why the grid size in the transform is set to 7. Similarly, the local views have 96px inputs and a grid size of 3, as 96/32 = 3.
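As a rough sketch of that arithmetic, assume a torchvision-style ResNet in which the stem convolution, the max-pool, and the first blocks of layer2, layer3 and layer4 each halve the spatial size, rounding up. Under that assumption, the size of the final feature map (and therefore the grid size the loss expects) can be computed per axis. The helper name `resnet_feature_map_size` is hypothetical and not part of lightly.

```python
def resnet_feature_map_size(input_size: int, num_stride2_stages: int = 5) -> int:
    """Spatial size of the last feature map of a torchvision-style ResNet.

    Hypothetical helper. Assumes each stride-2 stage (stem conv, max-pool,
    layer2, layer3, layer4) maps a size n to ceil(n / 2), which is what the
    padding used in those layers produces.
    """
    size = input_size
    for _ in range(num_stride2_stages):
        size = (size + 1) // 2  # ceil(size / 2) using integer arithmetic
    return size


# Global views: 224px -> 7x7 feature map, matching the default global_grid_size=7.
global_grid = resnet_feature_map_size(224)
# Local views: 96px -> 3x3 feature map, i.e. a grid size of 3.
local_grid = resnet_feature_map_size(96)
# Applied per axis, a 224x96 view would need a 7x3 grid.
rect_grid = (resnet_feature_map_size(224), resnet_feature_map_size(96))
print(global_grid, local_grid, rect_grid)
```

Under the same assumption, a grid size of 9 would only match global views between 257 and 288px, so it cannot line up with the 7x7 output of a resnet18 on 224px images.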
What we should probably do is add a check in the forward pass of the loss that verifies that the spatial size of the features equals the grid size.
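A minimal sketch of such a check, assuming the loss receives features laid out as (batch, height, width, dim) and grids as (batch, height, width, 2). The function name and that layout are assumptions, not lightly's actual API. Because `torch.Size` is a tuple subclass, the function could be called with `tensor.shape` directly.

```python
from typing import Sequence


def check_features_match_grid(
    features_shape: Sequence[int], grid_shape: Sequence[int]
) -> None:
    """Raise a readable error if the feature map and the grid disagree.

    Hypothetical helper. Assumes features are (batch, height, width, dim)
    and grids are (batch, height, width, 2).
    """
    if len(features_shape) != 4 or len(grid_shape) != 4:
        raise ValueError(
            f"Expected 4D features and grid, got shapes {tuple(features_shape)} "
            f"and {tuple(grid_shape)}."
        )
    feature_hw = tuple(features_shape[1:3])
    grid_hw = tuple(grid_shape[1:3])
    if feature_hw != grid_hw:
        raise ValueError(
            f"Feature map size {feature_hw} does not match grid size {grid_hw}. "
            "The grid size in the transform must equal the spatial size of the "
            "backbone output (e.g. 7 for a resnet18 with 224px inputs)."
        )


# Matching shapes pass silently.
check_features_match_grid((8, 7, 7, 512), (8, 7, 7, 2))

# A 9x9 grid with a 7x7 feature map fails early with a clear message.
try:
    check_features_match_grid((8, 7, 7, 512), (8, 9, 9, 2))
except ValueError as err:
    print(err)
```

Checking shapes up front catches the mismatch before any indexing happens, so the user sees this message instead of an opaque CUDA error.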
Oh, I see! Thank you!
Let's say we have resnet18 and a 224x96 global view instead of 224x224; in that case, the grid should be 7x3, right?
Yes :) Although the grid size is currently fixed to a single number (same height and width), so 7x3 won't be possible.
Certainly, I had to override it for 3D anyway, so it's not an issue 😁 Thank you so much!
Would it be useful to catch the error, though? Simply checking whether the max value of min_indices is out of bounds for input_maps or candidate_maps would do. Let me know and I'll do a PR.
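The bounds check proposed here could look roughly like the sketch below. It works on flat Python sequences to stay framework-agnostic; with PyTorch tensors one would pass something like `min_indices.flatten().tolist()`. The function name and its arguments are hypothetical, and the sketch only validates indices; it does not reproduce how lightly computes min_indices.

```python
from typing import Sequence


def check_neighbor_indices(min_indices: Sequence[int], num_candidates: int) -> None:
    """Fail with a clear message instead of an opaque out-of-bounds error.

    Hypothetical helper: min_indices are nearest-neighbor indices that will
    be used to index into a map with num_candidates entries.
    """
    if num_candidates <= 0:
        raise ValueError(f"num_candidates must be positive, got {num_candidates}.")
    if len(min_indices) == 0:
        return
    min_index = min(min_indices)
    max_index = max(min_indices)
    if min_index < 0 or max_index >= num_candidates:
        raise IndexError(
            f"Nearest-neighbor index range [{min_index}, {max_index}] is out of "
            f"bounds for {num_candidates} candidates. This usually means the "
            "grid size does not match the spatial size of the features."
        )


# A 7x7 feature map flattened to 49 candidates: indices 0..48 are valid.
check_neighbor_indices([0, 12, 48], num_candidates=49)

try:
    # Index 80 points past the end of a 49-entry map.
    check_neighbor_indices([3, 80], num_candidates=49)
except IndexError as err:
    print(err)
```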
I believe we should check in the VICRegLLoss forward pass that the features have the same spatial dimensions as the grid: https://github.com/lightly-ai/lightly/blob/00461d139d2a4280d91cbbc0d4d32d8766ef4227/lightly/loss/vicregl_loss.py#L97
PR would be very much appreciated!
Reproduce
Run https://docs.lightly.ai/self-supervised-learning/examples/vicregl.html with
transform = VICRegLTransform(n_local_views=0, global_grid_size=9)
(the default is global_grid_size=7). If you run it on GPU, you'll get a very ugly CUDA error that doesn't say much. If you run it on CPU, you can get a backtrace.
Description
Calculating nearest_neighbors in VICRegLLoss's _nearest_neighbors_on_l2 and _nearest_neighbors_on_grid causes the error described above. What happens is that the indices calculated in min_indices are sometimes out of bounds for the given input: https://github.com/lightly-ai/lightly/blob/dda9e8405d1460271e18d5200330b94bfb87c39f/lightly/models/utils.py#L518
That's as far as I've gotten; I haven't looked further into it and don't understand exactly what is going on. I believe your implementation is correct, but it would be useful to catch the error when the specified grid size is too big and show the user a helpful error message.