Yonghongwei closed this issue 2 years ago.
No, I don't think it's equivalent to SGD. SGD does not work for GANs, while AdaBelief does. It depends on which default eps you are using: I'm using 1e-16; if you set it to 1e-8, it would behave as you mentioned. PS: EAdam is inspired by AdaBelief.
Thanks for your reply. For classification and detection tasks, I find this problem occurs; I have not checked other tasks. On these high-level vision tasks, SGD usually performs well, and I also find that a small eps drops the performance there. Therefore, I think the main reason AdaBelief works well on high-level vision tasks is that it behaves like SGD.
Thanks for the reply. It's possible but not always true. I have tried eps=1e-16 with ResNet18 on ImageNet at a batch size of 4096; it achieved above 72% top-1 accuracy, much better than with eps=1e-8 and even better than SGD. I'm not sure whether it's a framework issue, since I used the Flax implementation to run on TPU. I feel that even for classification tasks, with proper decoupled weight decay and a good lr schedule, adaptive optimizers can perform no worse than SGD, but it would require some tuning.
Please print some statistics of the variable 'exp_avg_var' (e.g., print(exp_avg_var.min())) for each parameter group. You will find that the adaptive learning rates for different parameters are almost the same.
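For reference, a minimal sketch of how one could inspect this with the PyTorch implementation, assuming exp_avg_var is stored in the per-parameter optimizer state under the key 'exp_avg_var' (the helper name inspect_exp_avg_var is just for illustration and not part of the repository):

```python
def inspect_exp_avg_var(optimizer):
    # Walk every parameter group and print summary statistics of the
    # stored 'exp_avg_var' state tensor for each parameter.
    for group_idx, group in enumerate(optimizer.param_groups):
        for p in group['params']:
            state = optimizer.state.get(p, {})
            if 'exp_avg_var' not in state:
                continue  # parameter has not been updated yet
            v = state['exp_avg_var']
            print(f"group {group_idx}: min={v.min().item():.3e} "
                  f"max={v.max().item():.3e} mean={v.mean().item():.3e}")
```

Calling this every few hundred training steps should show the per-parameter statistics drifting toward the same value if the accumulation described below is happening.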
Please note this line of the code: "denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])". The in-place 'add_()' operator in 'exp_avg_var.add_(group['eps'])' changes the value of the variable 'exp_avg_var' itself. As a result, exp_avg_var accumulates eps at every iteration; the values of exp_avg_var for different parameters constantly increase and become almost the same. So your method is actually just equivalent to SGD with a changing global learning rate. I guess you referred to EAdam to introduce eps, which has the same problem. This is a very severe problem, because the 'belief' described in your paper makes no contribution to the final performance in your code.
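To make the effect concrete, here is a stripped-down toy (not the repository's code, and omitting the beta2 EMA update of exp_avg_var) showing how the in-place add_ bakes eps into the stored state on every step, whereas an out-of-place add would not:

```python
import torch

eps = 1e-8
# Two parameters whose 'belief' variances differ by six orders of magnitude.
exp_avg_var = torch.tensor([1e-12, 1e-6])

for _ in range(10_000):
    # In-place add_ as in the quoted line: it mutates exp_avg_var itself,
    # so eps accumulates in the stored state at every step.
    denom = exp_avg_var.add_(eps).sqrt()

print(exp_avg_var)  # both entries ~1e-4: the accumulated 10000 * eps
                    # swamps the original variances
# The out-of-place form would leave the stored state untouched:
#     denom = exp_avg_var.add(eps).sqrt()
```

After 10,000 steps both entries sit near 1e-4, so the denominators, and hence the adaptive learning rates, are nearly identical even though the variances originally differed by six orders of magnitude.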