nikitadurasov / masksembles

Official repository for the paper "Masksembles for Uncertainty Estimation" (CVPR 2021).
https://www.norange.io/projects/masksembles/
MIT License

How to use masksembles? #3

Open ZhouCX117 opened 3 years ago

ZhouCX117 commented 3 years ago

Hi, thank you for your excellent work! I have read the Masksembles layer code, but I have a little trouble understanding how to use it. In practice, does it mean replacing the dropout layer with a Masksembles layer?

Could you give a more detailed example?

I have run the following test code successfully:

```python
import torch
from masksembles.torch import Masksembles1D

layer = Masksembles1D(20, 10, 2.)      # channels=20, n=10 masks, scale=2.0
output = layer(torch.randn([40, 20]))  # batch of 40 is divisible by n=10
print(output)
```

nikitadurasov commented 3 years ago

Hey @ToBeNormal, thanks!

Could you elaborate a bit on what you would like to be added? I plan to provide a tutorial on how to use Masksembles layers on, e.g., the CIFAR dataset; would that be good?

> In practice, does it mean replacing the dropout layer with a Masksembles layer?

In general, yes, that's the simplest way to use it. You could start by inserting Masksembles layers in place of your dropout layers; in our experiments this worked well and improved the quality of the estimated uncertainty.

On the other hand, if you would like to achieve the single-model-to-ensembles transition that we describe in the paper, then you also need to increase the number of channels in your layers. From a more practical point of view, though, I would recommend the first option.
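
A minimal sketch of that first option, assuming the PyTorch layer shipped in this repo (`masksembles.torch.Masksembles1D`) and hypothetical layer sizes:

```python
import torch
import torch.nn as nn
from masksembles.torch import Masksembles1D

M = 4  # number of submodels (masks)

model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    # was: nn.Dropout(0.5), replaced by a Masksembles layer
    Masksembles1D(128, M, 2.0),  # channels, number of masks, scale
    nn.Linear(128, 10),
)

x = torch.randn(32, 784)  # batch size 32 is divisible by M = 4
logits = model(x)         # shape: [32, 10]
```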

ZhouCX117 commented 3 years ago

A tutorial on how to use Masksembles layers on, e.g., the CIFAR dataset would be very helpful! Thanks a lot!

nikitadurasov commented 3 years ago

Hey @ToBeNormal, sorry for the slow reply. I've updated the README with some examples and added a Colab notebook with MNIST. It would be great if you could check it and comment on whether it works for you.

ZhouCX117 commented 3 years ago

Hi, sorry for the late reply. I just ran the code in the Colab notebook, and it works well. However, I have one question: why must the batch size be divisible by M? In the example, M equals 4, and the train or test batch size must be divisible by 4 (e.g., 4, 8, 128).

Does the batch size correspond to the number of sampling passes during the test process?

ZhouCX117 commented 3 years ago

@nikitadurasov Hi, during testing can we only evaluate one sample at a time? Does the test batch size have to be 1?

ZhouCX117 commented 3 years ago

@nikitadurasov Hi, I don't understand the meanings of "total model size" and "model size". Could you explain? Thanks a lot!

ZhouCX117 commented 3 years ago

@nikitadurasov Hi, does the Masksembles layer add learnable parameters? In Table 1, the model size of MC-dropout is 1x while the model size of Masksembles is 2.3x, which I can't understand. Does the model size include non-learnable parameters? Why isn't MC-dropout's size equal to Masksembles'?

ZhouCX117 commented 3 years ago

@nikitadurasov Hi! Could you tell me the batch sizes used during the training and testing processes?

nikitadurasov commented 2 years ago

Hey @ToBeNormal,

To put it plainly: when you have a batch of N samples and there are M models in Masksembles, then after inference you still get a batch of N predictions. The trick is that, because of the current implementation, the first N / M predictions in the batch correspond to the first submodel, the next N / M to the second submodel, and so on. That's why N % M == 0 is required.
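
To make the layout concrete, here is a small sketch (the prediction tensor is a stand-in and the shapes are hypothetical):

```python
import torch

M = 4                  # number of Masksembles submodels
N = 32                 # batch size; N % M == 0 is required
num_classes = 10

# stand-in for a Masksembles model's output on a batch of N inputs:
# rows [0, N // M) come from submodel 0, the next N // M rows from
# submodel 1, and so on
preds = torch.randn(N, num_classes)
per_submodel = preds.reshape(M, N // M, num_classes)
print(per_submodel.shape)  # torch.Size([4, 8, 10])

# to get every submodel's prediction for a single test input,
# tile it M times so each submodel receives one copy
x = torch.randn(1, 20)
x_tiled = x.repeat(M, 1)   # feed this [M, 20] batch through the model;
                           # output row k is then submodel k's prediction
```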

About total size: imagine a simple NN with only one hidden layer. The input and output sizes are fixed (say they are I and O, with a hidden layer of size H). If you add a Masksembles layer after the hidden layer, then every Masksembles submodel effectively sees a smaller number of active neurons and, in general, has a smaller capacity. To avoid that, we increase H for the Masksembles model so that the submodels' sizes match the original model.
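
A toy calculation of that widening, with made-up numbers and a simplified "active fraction" standing in for the exact mask-generation formula:

```python
H = 128               # hidden width of the original single model
f = 0.5               # assumed fraction of channels each mask keeps active
H_prime = int(H / f)  # widened hidden layer for Masksembles: 256

# Each submodel now sees about H_prime * f = 128 active channels,
# matching the capacity of the original hidden layer. This widening is
# why the Masksembles model in Table 1 is larger than the 1x baseline.
```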

Hope this helps!

WateverOk commented 10 months ago

Hello, do you have a PyTorch version of this code?