google-deepmind / jax_privacy

Algorithms for Privacy-Preserving Machine Learning in JAX
Apache License 2.0

Tips on training without DP for baseline results #7

Closed: steverab closed this issue 1 year ago

steverab commented 1 year ago

Hi!

I was wondering whether you have any recommendations for running your code without DP to get a baseline model with the same architecture. I tried adjusting the config file as follows:

  1. Disable clipping and unit-norm rescaling by setting clipping_norm to None and rescale_to_unit_norm to False, respectively.
  2. Set the added noise (std_relative) to 0.
  3. Set the maximum epsilon (stop_training_at_epsilon) to None.
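To see why these changes should turn DP-SGD into ordinary mini-batch SGD, here is a minimal NumPy sketch of a DP-SGD-style gradient aggregation step. It is a hypothetical illustration, not jax_privacy's implementation: the function name aggregate_gradients is made up, and it assumes the standard DP-SGD recipe in which per-example gradients are clipped to clipping_norm, summed, perturbed with Gaussian noise of standard deviation std_relative * clipping_norm, and averaged over the batch. Unit-norm rescaling is left out for brevity.

```python
import numpy as np


def aggregate_gradients(per_example_grads, clipping_norm, std_relative, rng):
    """Hypothetical DP-SGD-style aggregation of per-example gradients.

    per_example_grads: array of shape [batch_size, num_params].
    clipping_norm: float, or None to skip clipping entirely.
    std_relative: noise multiplier; noise std is std_relative * clipping_norm.
    rng: a numpy.random.Generator used to draw the Gaussian noise.
    """
    grads = np.asarray(per_example_grads, dtype=np.float64)
    batch_size = grads.shape[0]
    if clipping_norm is not None:
        # Rescale each per-example gradient so its L2 norm is at most clipping_norm.
        norms = np.linalg.norm(grads, axis=1, keepdims=True)
        scale = np.minimum(1.0, clipping_norm / np.maximum(norms, 1e-12))
        grads = grads * scale
    summed = grads.sum(axis=0)
    if std_relative:
        # Noise is calibrated to the clipping norm, so it needs a finite bound.
        if clipping_norm is None:
            raise ValueError('Noise requires a clipping_norm to calibrate against.')
        summed = summed + rng.normal(
            scale=std_relative * clipping_norm, size=summed.shape)
    # Average over the batch, as in a regular mini-batch gradient.
    return summed / batch_size
```

With clipping_norm=None and std_relative=0 both branches are skipped and the function returns the plain batch-mean gradient, which is what a non-private baseline would use.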

The comments in the config file seem to suggest that DP can be turned off during training by setting some of these values to None. With the above changes, however, the accuracy quickly goes to zero and the loss becomes NaN. Lowering the learning rate got rid of the NaNs, but the model still does no better than chance (i.e. it is stuck at 10% accuracy on CIFAR-10).

Are there any other settings I need to change in my config file or the code to disable DP and just train without privacy?

My config file looks like this:

```python
# Excerpt from the experiment config; `config`, `configdict` and `data` are
# defined or imported earlier in the same config file.
config.experiment_kwargs = configdict.ConfigDict(
    dict(
        config=dict(
            num_updates=2468,
            optimizer=dict(
                name='sgd',
                lr=dict(
                    init_value=0.001,
                    decay_schedule_name=None,
                    decay_schedule_kwargs=None,
                    relative_schedule_kwargs=None,
                ),
                kwargs=dict(),
            ),
            model=dict(
                model_type='wideresnet',
                model_kwargs=dict(
                    depth=16,
                    width=4,
                ),
                restore=dict(
                    path=None,
                    params_key=None,
                    network_state_key=None,
                    layer_to_reset=None,
                ),
            ),
            training=dict(
                batch_size=dict(
                    init_value=4096,
                    per_device_per_step=64,
                    scale_schedule=None,  # example: {'2000': 8, '4000': 16},
                ),
                weight_decay=0.0,  # L-2 regularization,
                train_only_layer=None,
                dp=dict(
                    target_delta=1e-5,
                    clipping_norm=None,  # float('inf') or None to deactivate
                    stop_training_at_epsilon=None,  # None,
                    rescale_to_unit_norm=False,
                    noise=dict(
                        std_relative=0,  # noise multiplier
                    ),
                    auto_tune=None,  # 'num_updates',  # None,
                ),
                logging=dict(
                    grad_clipping=False,
                    grad_alignment=False,
                    snr_global=True,  # signal-to-noise ratio across layers
                    snr_per_layer=False,  # signal-to-noise ratio per layer
                ),
            ),
            averaging=dict(
                ema=dict(
                    coefficient=0.9999,
                    start_step=0,
                ),
                polyak=dict(
                    start_step=0,
                ),
            ),
            data=dict(
                dataset=data.get_dataset(
                    name='cifar10',
                    train_split='train_valid',  # 'train' or 'train_valid'
                    eval_split='test',  # 'valid' or 'test'
                ),
                random_flip=True,
                random_crop=True,
                augmult=0,  # implements arxiv.org/abs/2105.13343
            ),
            evaluation=dict(
                batch_size=500,
            ),
        ),
    ),
)
```
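For reference, a few quantities implied by the config above can be worked out directly. This plain-Python sketch rests on two assumptions that the thread does not confirm: that the 'train_valid' split is the full 50,000-image CIFAR-10 training set, and that init_value is the global batch size of one update while per_device_per_step is the per-device micro-batch, with gradients accumulated in between (which is how the field names read). The helper accumulation_steps is hypothetical.

```python
# Values copied from the config above.
NUM_UPDATES = 2468        # config.num_updates
BATCH_SIZE = 4096         # training.batch_size.init_value
PER_DEVICE_PER_STEP = 64  # training.batch_size.per_device_per_step

# Assumption: the 'train_valid' split is the full 50,000-image CIFAR-10 train set.
NUM_TRAIN_EXAMPLES = 50_000


def accumulation_steps(num_devices: int) -> int:
    """Micro-batches accumulated per update, assuming BATCH_SIZE is the global
    batch of one update and it splits evenly across devices."""
    examples_per_step = PER_DEVICE_PER_STEP * num_devices
    if BATCH_SIZE % examples_per_step:
        raise ValueError('Batch size does not split evenly across devices.')
    return BATCH_SIZE // examples_per_step


# Number of passes over the training set implied by the config.
epochs = NUM_UPDATES * BATCH_SIZE / NUM_TRAIN_EXAMPLES
```

Under these assumptions the run covers roughly 202 epochs, and on a single device each update accumulates 64 micro-batches of 64 examples.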

Any hints towards fixing this would be highly appreciated! Thanks!

lberrada commented 1 year ago

Hi, that's right: the changes you have made should be enough to train without differential privacy (as far as I can tell). It is now probably a matter of tuning the hyper-parameters to get the model to train correctly; the batch-norm-free WideResNet architecture that we use may be more sensitive to these than the standard version.
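One hyper-parameter that may be worth checking in a run this short is the parameter averaging: the config uses ema.coefficient=0.9999 with start_step=0 and num_updates=2468. If accuracy is reported on the EMA parameters (the thread does not say which parameters were evaluated), the sketch below estimates how much weight they still place on the random initialization. It assumes a standard EMA update without bias correction, ema = coefficient * ema + (1 - coefficient) * params, starting from the initial parameters; jax_privacy's actual averaging code may differ, and ema_init_weight is a hypothetical helper.

```python
def ema_init_weight(coefficient: float, num_steps: int) -> float:
    """Weight that an EMA without bias correction still puts on its initial value.

    Assumes ema <- coefficient * ema + (1 - coefficient) * params at each step,
    with ema initialized to the initial params. Unrolling the recursion leaves
    coefficient ** num_steps of the initial value in the average.
    """
    return coefficient ** num_steps
```

Under that assumption, roughly 78% of the EMA parameters would still come from the initialization after 2,468 updates, whereas a coefficient of 0.99 would have essentially forgotten it.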

Here are some reasonable hyper-parameters that allow the model to train correctly on my end when DP is deactivated (though not optimal by any means):

steverab commented 1 year ago

Thanks for the help!