tingxueronghua / pytorch-classification-advprop


Questions about mixbn #2

Closed: XiaoqiangZhou closed this issue 3 years ago

XiaoqiangZhou commented 3 years ago

Thanks for your re-implementation!

I'm confused by one implementation detail in your code. Could you please help me?

In the code, you assume the input to the model consists of two parts, a main part and an aux part, and you use two different BN layers to process them. However, this may go wrong when combined with DataParallel.

For example, suppose the whole batch is [main_1, main_2, main_3, main_4, aux_1, aux_2, aux_3, aux_4] and you use two GPUs to train the model. During training, [main_1, main_2, main_3, main_4] is assigned to the first GPU and [aux_1, aux_2, aux_3, aux_4] to the second GPU. Since the code assumes the first half of each mini-batch is the main part and the second half is the aux part, [main_1, main_2] would go through the main normalization layer and [main_3, main_4] through the aux normalization layer, which seems wrong.
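For context, a minimal sketch of the kind of split-BN layer being discussed (the `main_bn`/`aux_bn` names are assumed for illustration, not the repo's exact code):

```python
import torch
import torch.nn as nn

class MixBatchNorm2d(nn.Module):
    # Rough sketch of the mixbn idea: the first half of the batch is normalized
    # by the "main" BN, the second half by the "aux" BN used for adversarial
    # examples. This is why the per-GPU layout under DataParallel matters.
    def __init__(self, num_features):
        super().__init__()
        self.main_bn = nn.BatchNorm2d(num_features)
        self.aux_bn = nn.BatchNorm2d(num_features)

    def forward(self, x):
        half = x.shape[0] // 2
        main_out = self.main_bn(x[:half])
        aux_out = self.aux_bn(x[half:])
        return torch.cat([main_out, aux_out], dim=0)
```

With the batch laid out as [main_1..main_4, aux_1..aux_4] and naive DataParallel scattering, GPU 0 would receive only main samples and GPU 1 only aux samples, yet each replica would still treat the second half of its local chunk as the aux part.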

tingxueronghua commented 3 years ago

Thanks for your attention! To summarize, the implementation is correct. I am really glad that someone found the same problem I did :)

In fact, before I send the data into mixbn and get the outputs, I apply some extra operations. For example, given an input batch of shape [2*N, 3, 224, 224], before feeding it into the model wrapped in DataParallel, I first reshape it to [2, N, 3, 224, 224] and then transpose the first two dimensions, so the data becomes [N, 2, 3, 224, 224]. At that point it works correctly with DataParallel. Similar operations are applied when I get the output data.

I was really excited that someone noticed this problem, because I was stuck on it for about a month T_T. If you still find something wrong, feel free to ask at any time.
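For later readers, here is roughly what the reply describes, as a hedged sketch (function names are illustrative, not taken from the repo):

```python
import torch

def interleave_before_dataparallel(x):
    # x has shape [2*N, C, H, W] laid out as [main_1..main_N, aux_1..aux_N].
    # [2*N, C, H, W] -> [2, N, C, H, W] -> [N, 2, C, H, W]
    # DataParallel then scatters along dim 0, so every GPU receives an equal
    # number of (main, aux) pairs instead of all-main or all-aux samples.
    n = x.shape[0] // 2
    return x.view(2, n, *x.shape[1:]).transpose(0, 1).contiguous()

def restore_inside_model(x):
    # Runs on each replica's chunk of shape [m, 2, C, H, W]:
    # [m, 2, C, H, W] -> [2, m, C, H, W] -> [2*m, C, H, W]
    # so the first half of the local batch is main and the second half is aux,
    # which is exactly the layout the mixbn layer expects.
    m = x.shape[0]
    return x.transpose(0, 1).contiguous().view(2 * m, *x.shape[2:])
```

A symmetric de-interleaving would then be applied to the gathered outputs, matching the "similar operations" mentioned for the output data.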

tingxueronghua commented 3 years ago

Just to mention: if the data and DataParallel are not handled correctly, there will be no improvement in standard accuracy.

XiaoqiangZhou commented 3 years ago

Thanks for your quick reply and kind answer.

Obviously, I missed some important details. This is an elegant implementation.

Thanks for your great implementation again! 👍