JiahuiYu / slimmable_networks

Slimmable Networks, AutoSlim, and Beyond, ICLR 2019, and ICCV 2019

How to collect post statistics of BN #10

Closed VectorYoung closed 5 years ago

VectorYoung commented 5 years ago

Hi Jiahui, I was trying to reproduce USNet, but I ran into an issue when computing the post-training BN statistics. After training (actually I use your released model), I first randomly assign a width, then call model.apply(width). Then, just like in training, I run the model forward to collect BN statistics. But after that, USBN.running_mean and running_var are still empty. Basically, I follow the procedures here. I have set model.train() and track_running_stats=True. Can you help me with that? Thanks.

JiahuiYu commented 5 years ago

Hi @VectorYoung ,

Thanks for your interest. For PyTorch I do the following:

  1. reset bn statistics
  2. put model in training mode
  3. sample training images and forward for several times
  4. do inference

If your running_mean and running_var are empty, I guess you have met an implementation bug.
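The four steps above can be sketched with plain PyTorch as follows. This is a minimal illustration on a toy model with standard nn.BatchNorm2d layers, not the repository's own calibration code; the helper name recalibrate_bn and the toy model are assumptions made for the example.

```python
import torch
import torch.nn as nn


def recalibrate_bn(model, batches):
    """Re-estimate BatchNorm running statistics from a few training batches.

    Hypothetical helper for illustration; not part of slimmable_networks.
    """
    # 1. Reset BN statistics (running_mean -> 0, running_var -> 1, counter -> 0).
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.reset_running_stats()
    # 2. Training mode, so BN layers update their running statistics.
    model.train()
    # 3. Forward sampled training images several times; no gradients needed.
    with torch.no_grad():
        for x in batches:
            model(x)
    # 4. Switch to eval mode so inference uses the freshly collected statistics.
    model.eval()
    return model


torch.manual_seed(0)
# Toy model and random "training images", both assumptions for this sketch.
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
batches = [torch.randn(16, 3, 8, 8) + 2.0 for _ in range(10)]
recalibrate_bn(model, batches)
bn = model[1]
```

After the call, bn.num_batches_tracked counts the calibration batches and bn.running_mean is no longer at its reset value, which is a quick way to confirm the statistics were actually collected.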

VectorYoung commented 5 years ago

Hi Jiahui, thanks for your reply. I am trying to implement USNet based on your released code. In your code, the forward function of USBatchNorm is:

    y = nn.functional.batch_norm(
        input, self.running_mean, self.running_var,
        weight[:c], bias[:c],
        self.training, self.momentum, self.eps)

I just follow the way you do the forward pass. So if I set the model to training mode and USBatchNorm's track_running_stats = True, I expect it to update self.running_mean and self.running_var. But it doesn't. Can you give me some idea of what went wrong? I am new to PyTorch, so please forgive my naive questions.
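One thing worth checking in a situation like this is whether the running_mean and running_var passed to nn.functional.batch_norm are real tensors or None. The sketch below only demonstrates the behavior of the functional call itself on made-up tensors; it does not claim anything about how the released USBatchNorm sets up its buffers.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(32, 4, 8, 8) + 3.0  # per-channel mean is roughly 3
weight, bias = torch.ones(4), torch.zeros(4)

# Case 1: buffers are tensors -> they are updated in place when training=True.
running_mean, running_var = torch.zeros(4), torch.ones(4)
F.batch_norm(x, running_mean, running_var, weight, bias, True, 0.1, 1e-5)
batch_mean = x.mean(dim=(0, 2, 3))
# With momentum 0.1 and a zero start: running_mean == 0.1 * batch_mean.

# Case 2: buffers are None -> batch statistics normalize the input,
# but there is nothing to update, so no running statistics are kept.
y = F.batch_norm(x, None, None, weight, bias, True, 0.1, 1e-5)
```

If a layer's buffers are None at construction time, changing the track_running_stats attribute afterwards does not by itself create them, so printing layer.running_mean before the forward pass is a cheap sanity check.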

JiahuiYu commented 5 years ago

@VectorYoung

If you set some random width in USNet, then the forward should go into this, right? https://github.com/JiahuiYu/slimmable_networks/blob/85d14e882faae06747363642f81bee429348a089/models/slimmable_ops.py#L179-L189

VectorYoung commented 5 years ago

Hi @JiahuiYu, if I set a random width that is not in FLAGS.width_mult_list, I think it should go to the 'else' branch, and that should be the case when I want to do inference for arbitrary widths, right? Thanks.

JiahuiYu commented 5 years ago

@VectorYoung In this case, you need to put your desired width into FLAGS.width_mult_list in the config.

VectorYoung commented 5 years ago

Hi @JiahuiYu

Thanks for your help. I have mostly figured it out. Just one more simple question: I tried to set momentum=None to get the exact average (since you said it is better than the moving average), but nn.functional.batch_norm cannot take momentum=None (NoneType). So how do you achieve the exact average? Do you set momentum=(t-1)/t for each batch (t stands for the batch index)?
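For reference, PyTorch updates running statistics as running = (1 - momentum) * running + momentum * batch_stat, so the weight on the new batch is momentum itself. Under that convention, an equal-weight average over t batches corresponds to momentum = 1/t with t starting at 1. The sketch below checks this numerically on random data; it illustrates the arithmetic only and is not taken from the repository.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_channels = 4
# Five made-up batches whose means drift, so a wrong weighting would show up.
batches = [torch.randn(16, num_channels, 8, 8) + i for i in range(5)]

running_mean = torch.zeros(num_channels)
running_var = torch.ones(num_channels)

for t, x in enumerate(batches, start=1):
    # momentum = 1/t gives every batch seen so far the same weight.
    # At t = 1 the momentum is 1.0, which overwrites the initial values.
    F.batch_norm(x, running_mean, running_var, None, None, True, 1.0 / t, 1e-5)

# Plain average of the per-batch channel means, for comparison.
expected_mean = torch.stack([x.mean(dim=(0, 2, 3)) for x in batches]).mean(dim=0)
```

The module classes (for example nn.BatchNorm2d) accept momentum=None and then compute this cumulative average internally using num_batches_tracked; the functional call expects a float, which matches the error described above.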

JiahuiYu commented 5 years ago

14

semin-park commented 5 years ago

> Hi @VectorYoung ,
>
> Thanks for your interest. For PyTorch I do the following:
>
>   1. reset bn statistics
>   2. put model in training mode
>   3. sample training images and forward for several times
>   4. do inference
>
> If your running_mean and running_var are empty, I guess you have met an implementation bug.

Hi, I have a small clarification question regarding this.

In number 3, by "sample training images and forward" do you mean that I should repeat this multiple times for random widths?

Also, in your apps/<...>.yml files, you have width_mult_list defined as something like this:

width_mult_list: [0.25, 0.275, 0.3, 0.325, 0.35, 0.375, 0.4, 0.425, 0.45, 0.475, 0.5, 0.525, 0.55, 0.575, 0.6, 0.625, 0.65, 0.675, 0.7, 0.725, 0.75, 0.775, 0.8, 0.825, 0.85, 0.875, 0.9, 0.925, 0.95, 0.975, 1.0]

By forwarding the model multiple times with random widths, is it assumed that correct running BN stats are accumulated for all of the widths defined in the list?
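One way to avoid depending on random sampling is to loop over every width in the list and calibrate each one explicitly. The sketch below assumes one private set of BN statistics per width; ToySwitchableBN, calibrate_all_widths, and the shortened width list are hypothetical stand-ins for illustration and are not the repository's USBatchNorm2d or its config.

```python
import torch
import torch.nn as nn

WIDTH_MULT_LIST = [0.25, 0.5, 0.75, 1.0]  # shortened, hypothetical list
MAX_CHANNELS = 16


class ToySwitchableBN(nn.Module):
    """Hypothetical stand-in: one private BatchNorm2d per listed width."""

    def __init__(self):
        super().__init__()
        self.width_mult = 1.0
        self.bn = nn.ModuleList(
            [nn.BatchNorm2d(int(MAX_CHANNELS * w)) for w in WIDTH_MULT_LIST]
        )

    def forward(self, x):
        idx = WIDTH_MULT_LIST.index(self.width_mult)
        c = int(MAX_CHANNELS * self.width_mult)
        # Only the first c channels are active at this width.
        return self.bn[idx](x[:, :c])


def calibrate_all_widths(model, batches):
    """Reset and re-collect BN statistics separately for each width."""
    model.train()
    with torch.no_grad():
        for idx, w in enumerate(WIDTH_MULT_LIST):
            model.width_mult = w
            model.bn[idx].reset_running_stats()
            for x in batches:
                model(x)
    model.eval()


torch.manual_seed(0)
model = ToySwitchableBN()
batches = [torch.randn(8, MAX_CHANNELS, 4, 4) for _ in range(3)]
calibrate_all_widths(model, batches)
tracked = [int(bn.num_batches_tracked) for bn in model.bn]
```

In this toy setup a forward pass at one width touches only that width's statistics, so each listed width has to be visited at least once; whether random sampling covers every width depends on how many passes are run.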