sahagobinda / GPM

Official [ICLR] Code Repository for "Gradient Projection Memory for Continual Learning"
MIT License

About batch normalization #9

Open pkuxmq opened 2 years ago

pkuxmq commented 2 years ago

Hi, thank you for your interesting and promising work. I have a question about the implementation of batch normalization (BN). In the code of all models, "track_running_stats" is set to "False" for every BN module. This means the module always normalizes with the mean and variance of the current batch during the forward pass, in both training and testing. Test results therefore depend on the test batch size and on how the test data is split into batches. This can make performance fluctuate across test settings: for example, batch statistics are inaccurate when the batch size is 1, and BN is known to degrade with small batch sizes. The common practice is to track the running mean and variance of the training data and to put BN in "eval" mode during testing, so that the tracked statistics are used and the test results are consistent.
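A minimal pure-Python sketch of this concern, assuming a single BN channel over scalar activations. The function names are hypothetical and this is not GPM's code. With batch statistics (the "track_running_stats=False" behavior), the normalized value of the same test sample depends on which other samples share its batch. With fixed statistics (eval mode with tracked statistics), it does not.

```python
import math

def bn_batch_stats(batch, eps=1e-5):
    """Normalize scalar activations with the batch's own mean and (biased) variance,
    which is what BN does when track_running_stats=False, in training and testing."""
    mean = sum(batch) / len(batch)
    var = sum((x - mean) ** 2 for x in batch) / len(batch)
    return [(x - mean) / math.sqrt(var + eps) for x in batch]

def bn_fixed_stats(batch, mean, var, eps=1e-5):
    """Normalize with fixed statistics, e.g. running statistics used in eval mode."""
    return [(x - mean) / math.sqrt(var + eps) for x in batch]

sample = 3.0
# Batch statistics: the output for the same sample changes with its batch-mates.
print(bn_batch_stats([sample]))           # [0.0]: in this scalar setting a batch of one always maps to 0
print(bn_batch_stats([sample, 1.0])[0])   # roughly +1.0
print(bn_batch_stats([sample, 5.0])[0])   # roughly -1.0 (sign flips)
# Fixed statistics: the output is independent of batch composition.
print(bn_fixed_stats([sample], mean=2.0, var=1.0)[0])
print(bn_fixed_stats([sample, 5.0], mean=2.0, var=1.0)[0])
```

For convolutional BN layers the statistics are also taken over spatial positions, so a batch of one does not collapse to exactly zero there. However, the outputs still depend on batch composition in the same way.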

The paper states that "batch normalization parameters are learned for the first task and shared with all the other tasks (following Mallya & Lazebnik (2018))". However, I checked the code of PackNet (https://github.com/arunmallya/packnet). It does not set "track_running_stats" to "False". Instead, it tracks the statistics during the first task and sets BN to "eval" mode for the remaining tasks, so the statistics stay fixed. This follows the common practice for BN, and the test results are consistent. Is there an additional consideration behind setting "track_running_stats=False" in this implementation? And how would a model with this setting be affected by different test settings, e.g. different test batch sizes?
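The PackNet-style scheme described above can be sketched in plain Python for a single BN channel. `ScalarBatchNorm` is a hypothetical helper, not code from GPM or PackNet. Its running-statistic update follows the momentum form used by PyTorch's BatchNorm (new = (1 - m) * old + m * batch). For simplicity, it uses the biased batch variance everywhere. Statistics are tracked during task 1 and then frozen by switching to eval mode, so data from later tasks cannot overwrite them.

```python
import math

class ScalarBatchNorm:
    """Hypothetical single-channel BN with running statistics (illustration only)."""

    def __init__(self, momentum=0.1, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.running_mean = 0.0
        self.running_var = 1.0
        self.training = True

    def __call__(self, batch):
        if self.training:
            # Training mode: normalize with batch statistics and update running stats.
            mean = sum(batch) / len(batch)
            var = sum((x - mean) ** 2 for x in batch) / len(batch)  # biased variance
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * var
        else:
            # Eval mode: use the tracked statistics and leave them untouched.
            mean, var = self.running_mean, self.running_var
        return [(x - mean) / math.sqrt(var + self.eps) for x in batch]

bn = ScalarBatchNorm()
# Task 1: statistics are tracked from task-1 data (batch mean 2.0, variance 1.0).
for _ in range(200):
    bn([1.0, 3.0])
# Remaining tasks: freeze BN so task-2 data cannot change the task-1 statistics.
bn.training = False
for _ in range(200):
    bn([10.0, 14.0])
print(bn.running_mean, bn.running_var)  # still close to 2.0 and 1.0
```

In PyTorch, the analogous step would be calling `.eval()` on each BN module (built with `track_running_stats=True`) after the first task. Because `model.train()` puts every submodule back into training mode, the BN modules would have to be switched to eval again after each such call. Freezing the BN affine parameters (for example with `requires_grad_(False)`) is a separate step.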

muyuuuu commented 1 year ago

Same question as @pkuxmq:

  1. Could GPM achieve the same performance with track_running_stats=True?
  2. PackNet tracks the statistics of the first task and sets BN to "eval" mode for the remaining tasks, so the statistics are fixed.

Why are BN layers treated specially?

muyuuuu commented 1 year ago

Additional:

GPM performs poorly with track_running_stats=True, even on the first task. Why?

Parsifal133 commented 1 year ago

I noticed that setting track_running_stats to True records the mean and variance of the new task's data. These statistics are then applied during inference on the old tasks, which hurts performance on those tasks.
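This effect can be illustrated with a simplified, pure-Python model of one BN channel. The helper names are hypothetical, and this is not GPM's training loop. If running statistics keep updating while task 2 is trained, they drift toward the task-2 data. The same task-1 sample is then normalized very differently at test time.

```python
import math

def update_running(running_mean, running_var, batch, momentum=0.1):
    """One training-mode update of BN running statistics (momentum form, biased variance)."""
    mean = sum(batch) / len(batch)
    var = sum((x - mean) ** 2 for x in batch) / len(batch)
    return ((1 - momentum) * running_mean + momentum * mean,
            (1 - momentum) * running_var + momentum * var)

def normalize(x, mean, var, eps=1e-5):
    """Eval-mode normalization with tracked statistics."""
    return (x - mean) / math.sqrt(var + eps)

rm, rv = 0.0, 1.0
for _ in range(200):                    # task-1 data: mean 2.0, variance 1.0
    rm, rv = update_running(rm, rv, [1.0, 3.0])
after_task1 = normalize(3.0, rm, rv)    # a task-1 sample, roughly +1.0

for _ in range(200):                    # task-2 data: mean 12.0, variance 4.0
    rm, rv = update_running(rm, rv, [10.0, 14.0])
after_task2 = normalize(3.0, rm, rv)    # same sample, now roughly -4.5

print(after_task1, after_task2)
```

The layers that follow were trained on activations normalized around the task-1 statistics. Under this shift they receive very different inputs for old-task samples.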