aws-samples / deep-learning-models

Natural language processing & computer vision models optimized for AWS

a higher test accuracy training script #13

Open Tron-x opened 4 years ago

Tron-x commented 4 years ago

Hi @jarednielsen, I'm running 32K batch size training on ImageNet with ResNet-50, but the test accuracy I get with this script, https://github.com/aws-samples/deep-learning-models/blob/master/legacy/models/resnet/tensorflow/train_imagenet_resnet_hvd.py, is 75.4%. Can you provide a training script that reaches a higher test accuracy, such as 76%? I've made some attempts myself, but it's a little difficult for me. Thank you!

jarednielsen commented 4 years ago

Hi @Tron-x, can you provide more details about the training environment setup (EC2 or SageMaker or EKS), how many nodes you're using and which type of node, and the hyperparameters you're running with?

Tron-x commented 4 years ago

The hyperparameters I am running with are in https://github.com/aws-samples/deep-learning-models/blob/master/legacy/models/resnet/tensorflow/dlami_scripts/train_more_aug.sh. I used 128 nodes with 1,024 GPUs in total, and the batch size on each GPU is 32.
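For reference, the numbers above multiply out to the 32K global batch like this. A minimal sketch: the 8-GPUs-per-node split is inferred from 1,024 GPUs / 128 nodes, and the linear learning-rate scaling rule shown is a common convention for large-batch training, not necessarily what the script uses:

```python
# Hypothetical reconstruction of the reported setup.
nodes = 128
gpus_per_node = 8          # assumption: 1024 GPUs / 128 nodes
per_gpu_batch = 32

global_batch = nodes * gpus_per_node * per_gpu_batch
print(global_batch)        # 32768, the "32K" global batch size

# Linear LR scaling rule (scale a reference LR by batch / 256),
# shown only as an illustration; the script's schedule may differ.
base_lr = 0.1              # assumed reference LR for batch size 256
scaled_lr = base_lr * global_batch / 256
print(scaled_lr)           # 12.8
```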

Tron-x commented 4 years ago

Hi @jarednielsen, there are some tricks in this paper, https://arxiv.org/abs/1812.01187, that I think could be added to the benchmark script to improve accuracy. Perhaaps because my skills are limited, my attempts to reproduce them have not gone well. If you have time, could you try adding some of the tricks from this paper? I think it would be very useful for us. Thank you!
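One of the simplest tricks in that paper is label smoothing (Section 5.2 of arXiv:1812.01187): spread a small probability mass ε across all classes instead of using hard one-hot targets. A framework-agnostic sketch in NumPy, assuming the usual (1 - ε) / (ε / K) formulation; the function name and default ε = 0.1 are illustrative, not taken from the repo's script:

```python
import numpy as np

def smooth_labels(one_hot, epsilon=0.1):
    """Label smoothing: the true class gets (1 - epsilon) plus its
    share of the uniform mass, every class gets epsilon / K."""
    k = one_hot.shape[-1]
    return one_hot * (1.0 - epsilon) + epsilon / k

# Example with 5 classes: the true class becomes 0.92, the rest 0.02.
targets = np.eye(5)[[0, 2]]
smoothed = smooth_labels(targets, epsilon=0.1)
print(smoothed[0])   # [0.92 0.02 0.02 0.02 0.02]
```

In TensorFlow this effect can also be had via the `label_smoothing` argument of `tf.keras.losses.CategoricalCrossentropy`, so no manual target rewriting is strictly needed.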

jarednielsen commented 4 years ago

Thanks for the feedback. ResNet isn't in active development, but we'll see whether any of the quick tweaks improve accuracy. It may be a few weeks before there's spare bandwidth to help out, though. Out of curiosity, what's your use case for extremely large-batch training?
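Among the paper's quick tweaks, cosine learning-rate decay with linear warmup (Section 5.1 of arXiv:1812.01187) is probably the cheapest to try. A minimal sketch, assuming linear warmup to a peak rate followed by cosine decay to zero; this is not the schedule implemented in the repo's script:

```python
import math

def cosine_lr(step, total_steps, warmup_steps, base_lr):
    """Linear warmup to base_lr, then cosine decay to zero."""
    if step < warmup_steps:
        # Warmup: ramp linearly from ~0 up to base_lr.
        return base_lr * (step + 1) / warmup_steps
    # Cosine decay over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))
```

Keras ships an equivalent built-in, `tf.keras.optimizers.schedules.CosineDecay`, which also supports a `warmup_steps` argument in recent TensorFlow versions.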

Tron-x commented 4 years ago

Thank you. I am exploring the loss of accuracy that comes with very large batch sizes, so I need a benchmark code base. My other project is designing better collective-communication algorithms.