xinghaochen / SLAB

[ICML 2024] Official PyTorch implementation of "SLAB: Efficient Transformers with Simplified Linear Attention and Progressive Re-parameterized Batch Normalization"

Setting about step #6

Open Journey7331 opened 2 months ago

Journey7331 commented 2 months ago

Hi, @xinghaochen @guojialong1 Thanks for your great work!

In your code slab_deit.py, step is set to 60000:

linearnorm = partial(LinearNorm, norm1=ln, norm2=RepBN, step=60000)

If I understand correctly, when training on ImageNet-1K with batch_size=1024, len(train_dataloader)=1252 and step=60000, linearnorm will have fully turned into RepBN at epoch ≈ 60000/1252 ≈ 47.

But in my training this setting still produced NaN (just as if only BN were used from the start), so is there a more general rule, e.g. step ≈ (epochs * len(train_dataloader)) * (1/2, 2/3 or 3/4)?

Also, in my humble opinion, 47 epochs may not be long enough to ensure stable training? Or did I set something up wrong?
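
For context, this is roughly how I understand the progressive schedule to behave (a simplified sketch of what the paper describes, not the repo's exact LinearNorm/RepBN code; the class names below are mine):

import torch
import torch.nn as nn

class RepBNSketch(nn.Module):
    """Simplified RepBN as described in the paper: BatchNorm plus a learnable alpha * x residual."""
    def __init__(self, dim):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1))
        self.bn = nn.BatchNorm1d(dim)

    def forward(self, x):                # x: (B, N, C)
        x = x.transpose(1, 2)            # BatchNorm1d normalizes over the channel dimension
        x = self.bn(x) + self.alpha * x
        return x.transpose(1, 2)

class ProgressiveNormSketch(nn.Module):
    """LN -> RepBN blend: lam decays linearly from 1 to 0 over `step` training iterations."""
    def __init__(self, dim, step=60000):
        super().__init__()
        self.register_buffer("remaining", torch.tensor(step))
        self.register_buffer("total", torch.tensor(step))
        self.ln = nn.LayerNorm(dim)
        self.repbn = RepBNSketch(dim)

    def forward(self, x):
        if self.training and self.remaining > 0:
            lam = self.remaining.float() / self.total.float()
            self.remaining -= 1          # one schedule step per training forward pass
            return lam * self.ln(x) + (1 - lam) * self.repbn(x)
        return self.repbn(x)             # after `step` iterations (and at eval) it is pure RepBN

With step=60000 and 1252 iterations per epoch, the blend reaches pure RepBN around epoch 47, so the remaining epochs effectively train with RepBN only.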

guojialong1 commented 2 months ago

Can you provide more details on your training setting?

Journey7331 commented 2 months ago

It seems to work after setting step ≈ (epochs * len(train_dataloader)) * (4/5). Thanks for your reply :)
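
For anyone who hits the same issue, this is the rule of thumb I ended up with (the epoch count below is just an example; LinearNorm, ln and RepBN come from the repo's slab_deit.py):

from functools import partial

epochs = 300                                   # example total training epochs
iters_per_epoch = 1252                         # len(train_dataloader) for ImageNet-1K, batch_size=1024
step = int(epochs * iters_per_epoch * 4 / 5)   # LN -> RepBN transition finishes at ~80% of training

linearnorm = partial(LinearNorm, norm1=ln, norm2=RepBN, step=step)

With a 300-epoch schedule this gives step ≈ 300480, i.e. the network keeps a LayerNorm component for most of training.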

Journey7331 commented 2 months ago

Hi @guojialong1, sorry to bother you, but I noticed that you train for another 10 epochs for PRepBN, as mentioned on page 5:

5.1. Image Classification
...
The linear decay steps T for PRepBN slightly varies across different backbones. Due to the variance shift induced by droppath, we freeze the model parameters and exclusively update the statistics of re-parameterized BatchNorm for 10 epochs at the end of training.

I'm curious about the exact impact of doing or skipping these extra 10 epochs. Can you share some results for this part? Thanks! :)
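
If I read that correctly, this phase would look roughly like the sketch below (my own reading of the paper, not code from this repo; model and train_dataloader are placeholders on my side):

import torch

# Freeze all learnable parameters; only the BatchNorm running statistics should move.
for p in model.parameters():
    p.requires_grad = False

model.train()                        # keep BN in train mode so running mean/var keep updating

with torch.no_grad():
    for _ in range(10):              # "for 10 epochs at the end of training"
        for images, _ in train_dataloader:
            model(images)            # forward passes only; no loss or optimizer step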

guojialong1 commented 2 months ago

In our results, the accuracy can be improved slightly. For example, the accuracy of Swin-T is improved by 0.1% with the help of another 10 epochs for PRepBN.

Journey7331 commented 2 months ago

Got it, many thanks!

Journey7331 commented 2 months ago

Hi @guojialong1, sorry to bother you again.

When I plug a backbone trained with PRepBN on ImageNet-1K into a downstream task such as segmentation (sem-fpn on ADE20K), the linear norm is effectively pure RepBN, since iter = 0 is loaded from the checkpoint.
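
(A quick way to check this is to inspect the schedule buffers stored in the checkpoint; the file name below is illustrative and the buffer name follows the iter counter mentioned above:)

import torch

ckpt = torch.load("slab_imagenet_backbone.pth", map_location="cpu")   # illustrative path
state = ckpt.get("model", ckpt)                                       # some checkpoints nest the state dict
for name, value in state.items():
    if name.endswith("iter"):                                         # schedule counter buffer
        print(name, value)                                            # 0 means the layer already behaves as pure RepBN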

Training goes well for the first two epochs, but mIoU (val) drops to 0 at epoch 2, and the loss turns to NaN from epoch 4.

Have you ever had this kind of problem?

logs:

epoch 0, `all goes well`
Iter(train) [   50/80000]  lr: 9.9945e-05  eta: 10:50:26  loss: 2.6102  decode.loss_ce: 2.6102  decode.acc_seg: 58.4337
Iter(train) [  100/80000]  lr: 9.9889e-05  eta: 8:36:52   loss: 2.0197  decode.loss_ce: 2.0197  decode.acc_seg: 59.5736
Iter(train) [  150/80000]  lr: 9.9832e-05  eta: 7:54:40   loss: 1.9503  decode.loss_ce: 1.9503  decode.acc_seg: 49.1268
...
+---------------------+-------+-------+
|        Class        |  IoU  |  Acc  |
+---------------------+-------+-------+
|         wall        | 67.05 | 85.94 |
|       building      | 76.66 | 93.02 |
|         sky         | 92.85 | 96.48 |
aAcc: 75.2300  mIoU: 28.2500  mAcc: 37.2000  data_time: 0.0019  time: 0.0709

epoch 1, `all goes well`
...
aAcc: 77.5900  mIoU: 34.3300  mAcc: 44.3400  data_time: 0.0017  time: 0.0385

epoch 2, `train goes well, val get IoU=0`
Iter(train) [16050/80000]  lr: 8.1749e-05  eta: 5:13:02  loss: 0.6718  decode.loss_ce: 0.6718  decode.acc_seg: 70.1025
Iter(train) [16100/80000]  lr: 8.1691e-05  eta: 5:12:47  loss: 0.6537  decode.loss_ce: 0.6537  decode.acc_seg: 71.3643
Iter(train) [16150/80000]  lr: 8.1634e-05  eta: 5:12:32  loss: 0.7025  decode.loss_ce: 0.7025  decode.acc_seg: 67.0324
...
+---------------------+-------+-------+
|        Class        |  IoU  |  Acc  |
+---------------------+-------+-------+
|         wall        | 17.54 | 100.0 |
|       building      |  0.0  |  0.0  |
|         sky         |  0.0  |  0.0  |
aAcc: 17.5400  mIoU: 0.1200  mAcc: 0.6700  data_time: 0.0017  time: 0.0381

epoch 3, `train goes well, val get IoU=0`
Iter(train) [24050/80000]  lr: 7.2484e-05  eta: 4:33:56  loss: 0.5841  decode.loss_ce: 0.5841  decode.acc_seg: 82.0962
Iter(train) [24100/80000]  lr: 7.2426e-05  eta: 4:33:42  loss: 0.5743  decode.loss_ce: 0.5743  decode.acc_seg: 85.1667
Iter(train) [24150/80000]  lr: 7.2368e-05  eta: 4:33:28  loss: 0.6144  decode.loss_ce: 0.6144  decode.acc_seg: 78.5210
...
+---------------------+-------+-------+
|        Class        |  IoU  |  Acc  |
+---------------------+-------+-------+
|         wall        | 17.54 | 100.0 |
|       building      |  0.0  |  0.0  |
|         sky         |  0.0  |  0.0  |
aAcc: 17.5400  mIoU: 0.1200  mAcc: 0.6700  data_time: 0.0017  time: 0.0407

epoch 4, `train get nan, val get IoU=0`
Iter(train) [32050/80000]  lr: 6.3086e-05  eta: 3:55:07  loss: 0.5303  decode.loss_ce: 0.5303  decode.acc_seg: 79.6364
Iter(train) [32100/80000]  lr: 6.3027e-05  eta: 3:54:52  loss: 0.5110  decode.loss_ce: 0.5110  decode.acc_seg: 73.6874
Iter(train) [32150/80000]  lr: 6.2968e-05  eta: 3:54:37  loss: 0.6026  decode.loss_ce: 0.6026  decode.acc_seg: 63.3229
...
Iter(train) [35000/80000]  lr: 5.9582e-05  eta: 3:41:00  loss: nan  decode.loss_ce: nan  decode.acc_seg: 75.3267
Iter(train) [35050/80000]  lr: 5.9522e-05  eta: 3:40:45  loss: 0.5658  decode.loss_ce: 0.5658  decode.acc_seg: 83.2890
Iter(train) [35100/80000]  lr: 5.9463e-05  eta: 3:40:30  loss: 0.4999  decode.loss_ce: 0.4999  decode.acc_seg: 81.7572
Iter(train) [35150/80000]  lr: 5.9403e-05  eta: 3:40:15  loss: nan  decode.loss_ce: nan  decode.acc_seg: 77.0405
Iter(train) [35200/80000]  lr: 5.9344e-05  eta: 3:40:00  loss: nan  decode.loss_ce: nan  decode.acc_seg: 3.3515
Iter(train) [35250/80000]  lr: 5.9284e-05  eta: 3:39:45  loss: nan  decode.loss_ce: nan  decode.acc_seg: 20.7496
...
^C
KeyboardInterrupt

guojialong1 commented 2 months ago

Do you freeze the parameters of RepBN and apply eval() mode to BatchNorm during training?

Journey7331 commented 2 months ago

Yes, I did:

import torch.nn as nn


class ModelBackbone(Model):
    def __init__(self, out_indices=None, pretrained=None, **kwargs):
        super().__init__(**kwargs)
        self.out_indices = out_indices
        self.load_pretrained(pretrained)
        self.train()

    def train(self, mode=True):
        """Convert the model into training mode while keeping the BatchNorm layers frozen."""
        super().train(mode)
        if mode:
            # Keep BN in eval mode so its running statistics are not updated during fine-tuning.
            for m in self.modules():
                if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                    m.eval()

guojialong1 commented 2 months ago

Perhaps you can try freezing the alpha parameter of RepBN, increasing the weight decay, or using the PRepBN strategy instead of plain RepBN.
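
For the first option, a rough sketch (model is your segmentation model with the SLAB backbone; this assumes RepBN is imported from this repo and exposes its scale as alpha):

# Freeze RepBN's learnable scale so it cannot drift during downstream fine-tuning.
for m in model.modules():
    if isinstance(m, RepBN):
        m.alpha.requires_grad = False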