leftthomas closed this issue 4 years ago.
I suspect the batch size you set is not 32 per GPU. The self.B in MABN should equal the batch size you set in the dataloader. We will refine the code later, but for now you can manually set the default value of self.B to match your batch size for a quick test.
@RuosOne MABN is aimed at solving the problem of small batch sizes, so if I can set the batch size to 32, why would I use MABN? You give a temporary solution, but if the batch size is 1 or an odd number, it still raises an error.
The cls code simulates the small batch issue, i.e., a large SGD batch size combined with a small normalization batch size. In the cls code, the normalization batch size is set to 2 while the SGD batch size is 256. Of course, you can set the normalization batch size to 1 or any other odd number you want.
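To make the "large SGD batch, small normalization batch" simulation concrete, here is a minimal pure-Python sketch for a single feature channel: the SGD batch is split into consecutive normalization groups and each sample is normalized with its own group's statistics. The function names and the divisibility check are illustrative assumptions, not the repository's implementation (which operates on PyTorch tensors).

```python
import math
from statistics import fmean, pvariance


def grouped_batch_stats(batch, norm_group_size=2):
    """Split one channel of an SGD mini-batch into consecutive normalization
    groups and return (mean, biased variance) for each group."""
    n = len(batch)
    if n % norm_group_size != 0:
        # e.g. an odd batch cannot be split evenly into groups of 2
        raise ValueError(
            f"batch size {n} is not divisible by normalization "
            f"group size {norm_group_size}"
        )
    return [
        (fmean(batch[i:i + norm_group_size]), pvariance(batch[i:i + norm_group_size]))
        for i in range(0, n, norm_group_size)
    ]


def normalize_in_groups(batch, norm_group_size=2, eps=1e-5):
    """Normalize each sample with the statistics of its own small group,
    simulating small-batch normalization inside a large SGD batch."""
    stats = grouped_batch_stats(batch, norm_group_size)
    return [
        (x - stats[i // norm_group_size][0])
        / math.sqrt(stats[i // norm_group_size][1] + eps)
        for i, x in enumerate(batch)
    ]
```

With an SGD batch of 256 and norm_group_size=2 this yields 128 independent groups. Any group size, odd or even, works in this sketch as long as it divides the SGD batch size.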
Technically, the small batch issue rarely arises in classification, so, as other small batch normalization methods do, we evaluate our method on the classification task through simulation. If you want to see its effectiveness on a real small batch problem, I suggest you test the det code; it is a genuine small batch case.
Again, the cls code is just a simulation of the small batch issue on ImageNet classification. If you want to test MABN on a real small batch problem, you can test the det code, or run a real small batch (1 or 2) experiment on ImageNet using the MABN implementation in the det code, which I do not recommend because it is really time-consuming.
Hope this helps. If you have any other questions that are not about the code but about MABN itself, please email us or leave a public comment at https://openreview.net/forum?id=SkgGjRVKDS&noteId=38ji3hend1; we welcome any thoughtful discussion.
Edit: we will add a formal MABN module to the repo later for easy use in real small batch settings.
We have added a formal MABN module for easy use.
According to your code, sta_matrix is a tensor of shape [16, 32] (https://github.com/megvii-model/MABN/blob/19e897f199408c00cd8e961d2af73e02ce8178c4/cls/networks/MABN.py#L82), but var is a tensor of shape [16+N//2, C], so the torch.mm call that multiplies these two tensors fails: https://github.com/megvii-model/MABN/blob/19e897f199408c00cd8e961d2af73e02ce8178c4/cls/networks/MABN.py#L21
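The shape mismatch can be reproduced without PyTorch. torch.mm(a, b) requires a to have shape (n, m) and b to have shape (m, p). The sketch below assumes, as an interpretation of the shapes quoted above and not a reading of the repository code, that sta_matrix has 16 + B//2 columns (16 buffered statistics plus one per normalization group of 2 in a batch of the configured size self.B = 32), while var has 16 + N//2 rows for the actual batch size N. All function names here are hypothetical.

```python
def sta_matrix_shape(buffer_len, configured_B, norm_group_size=2):
    # Assumption: one column per buffered statistic plus one per
    # normalization group in a batch of the configured size B.
    return (buffer_len, buffer_len + configured_B // norm_group_size)


def var_shape(buffer_len, actual_N, channels, norm_group_size=2):
    # Shape quoted in the issue: [16 + N//2, C] for the actual batch size N.
    return (buffer_len + actual_N // norm_group_size, channels)


def mm_compatible(a_shape, b_shape):
    # torch.mm(a, b) needs a: (n, m) and b: (m, p).
    return a_shape[1] == b_shape[0]


# Configured B = 32 reproduces the reported [16, 32] sta_matrix.
print(sta_matrix_shape(16, 32))
# Actual batch N = 32 matches B, so the product is defined.
print(mm_compatible(sta_matrix_shape(16, 32), var_shape(16, 32, 64)))
# Actual batch N = 8 gives var of shape (20, 64), so torch.mm would fail.
print(mm_compatible(sta_matrix_shape(16, 32), var_shape(16, 8, 64)))
```

Under this assumption, the error disappears exactly when the dataloader batch size N equals self.B, which is consistent with the suggestion to set self.B to the dataloader batch size.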