google-deepmind / jax_privacy

Algorithms for Privacy-Preserving Machine Learning in JAX
Apache License 2.0

Computational requirements for training WRN-16-4 on CIFAR-10 #3

Closed. Solosneros closed this issue 2 years ago.

Solosneros commented 2 years ago

Hi,

I was wondering if you could share some insights into the compute required to train the WRN-16-4 on CIFAR-10. Running cifar10_wrn_16_4_eps1.py on an NVIDIA A100 (80 GB) took around 160 minutes (~5.88 steps per second) on our setup.
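For what it's worth, here is the back-of-the-envelope arithmetic behind those numbers; I am assuming the logged steps are gradient-accumulation steps rather than parameter updates:

```python
# Sanity check of the reported numbers. Assumption: each logged "step" is one
# forward/backward pass over a per-device mini-batch (a gradient-accumulation
# step), not one full parameter update.
steps_per_sec = 5.88
runtime_minutes = 160

total_steps = steps_per_sec * runtime_minutes * 60
print(f"~{total_steps:,.0f} steps over the run")  # ~56,448
```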

Were your runtimes in the same ballpark? I understand that you ran into computational constraints when training on ImageNet, but I have not found any other details on compute requirements in your preprint.

Thanks in advance!

sohamde commented 2 years ago

Hi,

In our experiments, the cifar10_wrn_16_4_eps1.py script ran in ~50 minutes on 8 TPUv2 devices at ~2.4 steps per second. While these numbers are not directly comparable, it seems like you are getting reasonable runtimes.
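To make "not directly comparable" concrete: raw steps/sec depends on how many examples each step processes. A minimal sketch of that relationship follows; the per-device batch sizes in it are illustrative placeholders, not the actual config values:

```python
# Each step processes num_devices * per_device_per_step examples, so steps/sec
# alone does not determine throughput. Per-device batch sizes below are
# hypothetical placeholders, not the values from our configs.

def examples_per_sec(steps_per_sec: float, num_devices: int,
                     per_device_per_step: int) -> float:
    return steps_per_sec * num_devices * per_device_per_step

print(examples_per_sec(5.88, num_devices=1, per_device_per_step=64))  # A100 (assumed batch)
print(examples_per_sec(2.4, num_devices=8, per_device_per_step=32))   # 8x TPUv2 (assumed batch)
```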

Given that the A100 you are using has a lot of memory, you could likely speed up your experiments even further by increasing the per_device_per_step batch size, which reduces the number of gradient-accumulation steps per update; see the sketch below.
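A minimal sketch of the arithmetic behind that suggestion. The total batch size here is an assumption for illustration; the real value is set in the experiment config:

```python
# DP-SGD trains with a fixed (virtual) total batch size per parameter update,
# assembled via gradient accumulation. With the total batch size fixed, a
# larger per_device_per_step means fewer accumulation steps per update.
total_batch_size = 4096  # assumed virtual batch size; taken from the config in practice
num_devices = 1          # single A100

for per_device_per_step in (32, 64, 128):
    accumulation_steps = total_batch_size // (num_devices * per_device_per_step)
    print(f"per_device_per_step={per_device_per_step} -> "
          f"{accumulation_steps} accumulation steps per update")
# Doubling per_device_per_step halves the accumulation steps, so the same
# number of parameter updates finishes in roughly half the total steps.
```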

Solosneros commented 2 years ago

Thanks for the rapid answer and suggestions @sohamde!