google-deepmind / jax_privacy

Algorithms for Privacy-Preserving Machine Learning in JAX
Apache License 2.0

How to speed up training? #13

Closed heilrahc closed 1 year ago

heilrahc commented 1 year ago

I'm trying to reproduce the WRN-28-10 pre-trained results, but training is very slow: after running for 7 days it has only reached dp-epsilon 1.2. I can't imagine how long it would take to reach dp-epsilon 8.

I'm using an NVIDIA Tesla V100 32GB. And the config file I used is the default one on GitHub.

Is it supposed to take this long to run? Or is there a way to speed it up?

lberrada commented 1 year ago

Hi, this experiment is indeed expected to take a long time, because the hyper-parameters were tuned purely for accuracy, without regard to the computational cost.

With that said, your running times do seem a few times longer than expected. Some suggestions:

  1. We ran our experiment on different hardware, so a larger batch_size.per_device_per_step may fit in memory on your GPU, which would increase the number of samples processed per second. It is probably worth tuning that number to maximize throughput on your specific device.
  2. Check that GPU utilization stays high during training, in particular to make sure the pipeline is not bottlenecked by data loading.
  3. If the experiment remains too slow, you can trade some accuracy for a lower compute cost, e.g. by reducing the number of augmentations set by augmult: in this fine-tuning setting, the returns from a high augmult are likely small. This would not exactly reproduce our results, but the resulting accuracy should stay reasonably close (a sketch of this kind of config override, together with the batch-size change from point 1, follows below).
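
For concreteness, here is a minimal sketch of the kind of config overrides described in points 1 and 3. The attribute paths (`experiment_kwargs.config.training.batch_size.per_device_per_step`, `...data.augmult`) and the numeric values are assumptions for illustration; check the config file you are actually launching for the exact field names and defaults.

```python
import ml_collections


def adjust_for_throughput(config: ml_collections.ConfigDict) -> ml_collections.ConfigDict:
  """Trades a little accuracy and memory headroom for faster wall-clock training."""
  training = config.experiment_kwargs.config.training  # assumed path, see note above

  # Point 1: a larger per-device micro-batch means fewer gradient-accumulation
  # steps per update, as long as it still fits in the V100's 32GB of memory.
  training.batch_size.per_device_per_step = 64  # illustrative value, tune on your GPU

  # Point 3: fewer augmentation multiplicities roughly divides the per-example
  # compute, at the price of a (likely small) drop in fine-tuning accuracy.
  config.experiment_kwargs.config.data.augmult = 4  # illustrative value

  return config
```

You could apply a function like this to the default config before launching the experiment, or pass the equivalent overrides on the command line if your launcher supports config flags.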