lyakaap / VAT-pytorch

Virtual Adversarial Training (VAT) implementation for PyTorch

What's the point of _disable_tracking_bn_stats()? #8

Closed jonkoi closed 5 years ago

jonkoi commented 5 years ago

I don't understand what _disable_tracking_bn_stats() is trying to do. I don't think the network itself has a track_running_stats attribute, so the condition would never be met; if the attribute exists anywhere, you would have to go down to the batchnorm layers to find it.

I think you are trying to keep running_mean and running_var fixed while computing the VAT loss, but I don't think changing track_running_stats achieves that either.
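
A quick way to check the attribute question is to look at which modules actually carry track_running_stats. The sketch below is illustrative only: the toy network and the collect function are hypothetical, and it assumes the helper's check is run per submodule through Module.apply (the repository code is not quoted in this thread).

```python
import torch.nn as nn

# Hypothetical toy network, used only for illustration.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)

# The top-level container does not have the attribute itself.
container_has_flag = hasattr(model, "track_running_stats")

# Module.apply calls the function on every submodule and on the module itself,
# so a hasattr check inside it can still reach the batchnorm layers.
flagged = []

def collect(m):
    if hasattr(m, "track_running_stats"):
        flagged.append(type(m).__name__)

model.apply(collect)
print(container_has_flag, flagged)  # False ['BatchNorm2d']
```

So a check on the network object alone would never fire, while the same check applied recursively only matches the batchnorm layers.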

jizongFox commented 4 years ago

Setting track_running_stats=False lets you keep the statistics of the current samples from contributing to running_mean. However, in current PyTorch there is a bug: even with track_running_stats=False, running_mean is still updated. At inference, setting track_running_stats=False makes the network use the current batch statistics instead of running_mean.
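
If the goal is only to leave running_mean and running_var untouched during the extra forward passes, one approach that does not depend on how track_running_stats is handled is to snapshot the batchnorm buffers and restore them afterwards. This is a minimal sketch, not the repository's implementation; preserve_bn_stats and the toy model are hypothetical names introduced here.

```python
import contextlib

import torch
import torch.nn as nn


@contextlib.contextmanager
def preserve_bn_stats(model):
    """Hypothetical helper: snapshot batchnorm buffers, restore them on exit."""
    bn_layers = [m for m in model.modules()
                 if isinstance(m, nn.modules.batchnorm._BatchNorm)]
    # Buffers are running_mean, running_var and num_batches_tracked (if present).
    saved = [{name: buf.clone() for name, buf in m.named_buffers(recurse=False)}
             for m in bn_layers]
    try:
        yield
    finally:
        with torch.no_grad():
            for m, buffers in zip(bn_layers, saved):
                for name, buf in m.named_buffers(recurse=False):
                    buf.copy_(buffers[name])


# Usage on a hypothetical toy model in training mode.
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
model.train()
bn = model[1]
before = bn.running_mean.clone()

with preserve_bn_stats(model):
    model(torch.randn(16, 4) + 3.0)  # forward pass that would normally update the stats

unchanged = torch.equal(bn.running_mean, before)
print(unchanged)
```

Inside the block the layers still normalize with the current batch statistics, as they do in training mode; only the side effect on the running estimates is undone.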