google-deepmind / jax_privacy

Algorithms for Privacy-Preserving Machine Learning in JAX
Apache License 2.0

Reproducing the result for CIFAR10 #9

Closed · ahasanpour closed this 1 year ago

ahasanpour commented 1 year ago

Hi!

I am trying to reproduce the results for the CIFAR-10 dataset. In Table 3 of the "Unlocking High-Accuracy Differentially ..." paper, you report 79.5% accuracy with std. dev. 0.7 at epsilon 8, but I could not reach this accuracy. Please see the config and the output below. I used four A100 GPUs. Should I change anything in the config file? Thank you in advance.


I0401 12:14:07.170631 23210343319296 train.py:73] Training with config:
best_model_eval_metric: ''
best_model_eval_metric_higher_is_better: true
checkpoint_dir: /tmp/jax_privacy/ckpt_dir
checkpoint_interval_type: null
eval_initial_weights: false
eval_specific_checkpoint_dir: ''
experiment_kwargs:
  config:
    averaging:
      ema:
        coefficient: 0.9999
        start_step: 0
      polyak:
        start_step: 0
    data:
      augmult: 16
      dataset: !!python/object:jax_privacy.src.training.image_classification.data.data_info.Dataset
        eval: !!python/object:jax_privacy.src.training.image_classification.data.data_info.Split
          num_samples: 10000
          split_content: test
        name: cifar10
        num_classes: 10
        train: !!python/object:jax_privacy.src.training.image_classification.data.data_info.Split
          num_samples: 50000
          split_content: train
      random_crop: true
      random_flip: true
    evaluation:
      batch_size: 100
    model:
      model_kwargs:
        depth: 16
        width: 4
      model_type: wideresnet
      restore:
        layer_to_reset: null
        network_state_key: null
        params_key: null
        path: null
    num_updates: 875
    optimizer:
      kwargs: {}
      lr:
        decay_schedule_kwargs: null
        decay_schedule_name: null
        init_value: 2.0
        relative_schedule_kwargs: null
      name: sgd
    training:
      batch_size:
        init_value: 4096
        per_device_per_step: 64
        scale_schedule: null
      dp:
        auto_tune: null
        clipping_norm: 1.0
        noise:
          std_relative: 0.7
        rescale_to_unit_norm: true
        stop_training_at_epsilon: 10.0
        target_delta: 1.0e-05
      logging:
        grad_alignment: false
        grad_clipping: true
        snr_global: true
        snr_per_layer: false
      train_only_layer: null
      weight_decay: 0.0
interval_type: steps
log_all_train_data: false
log_tensors_interval: 100
log_train_data_interval: 100.0
logging_interval_type: null
max_checkpoints_to_keep: 5
one_off_evaluate: false
random_mode_eval: same_host_same_device
random_mode_train: same_host_same_device
random_seed: 972286
save_checkpoint_interval: 250
train_checkpoint_all_hosts: false
training_steps: 10000

I0401 12:14:07.175778 23214846677760 train.py:152] Evaluating with config:
best_model_eval_metric: ''
best_model_eval_metric_higher_is_better: true
checkpoint_dir: /tmp/jax_privacy/ckpt_dir
checkpoint_interval_type: null
eval_initial_weights: false
eval_specific_checkpoint_dir: ''
experiment_kwargs:
  config:
    averaging:
      ema:
        coefficient: 0.9999
        start_step: 0
      polyak:
        start_step: 0
    data:
      augmult: 16
      dataset: !!python/object:jax_privacy.src.training.image_classification.data.data_info.Dataset
        eval: !!python/object:jax_privacy.src.training.image_classification.data.data_info.Split
          num_samples: 10000
          split_content: test
        name: cifar10
        num_classes: 10
        train: !!python/object:jax_privacy.src.training.image_classification.data.data_info.Split
          num_samples: 50000
          split_content: train
      random_crop: true
      random_flip: true
    evaluation:
      batch_size: 100
    model:
      model_kwargs:
        depth: 16
        width: 4
      model_type: wideresnet
      restore:
        layer_to_reset: null
        network_state_key: null
        params_key: null
        path: null
    num_updates: 875
    optimizer:
      kwargs: {}
      lr:
        decay_schedule_kwargs: null
        decay_schedule_name: null
        init_value: 2.0
        relative_schedule_kwargs: null
      name: sgd
    training:
      batch_size:
        init_value: 4096
        per_device_per_step: 64
        scale_schedule: null
      dp:
        auto_tune: null
        clipping_norm: 1.0
        noise:
          std_relative: 0.7
        rescale_to_unit_norm: true
        stop_training_at_epsilon: 10.0
        target_delta: 1.0e-05
      logging:
        grad_alignment: false
        grad_clipping: true
        snr_global: true
        snr_per_layer: false
      train_only_layer: null
      weight_decay: 0.0
interval_type: steps
log_all_train_data: false
log_tensors_interval: 100
log_train_data_interval: 100.0
logging_interval_type: null
max_checkpoints_to_keep: 5
one_off_evaluate: false
random_mode_eval: same_host_same_device
random_mode_train: same_host_same_device
random_seed: 972286
save_checkpoint_interval: 250
train_checkpoint_all_hosts: false
training_steps: 10000

/cluster/home/image_venv310/lib/python3.10/site-packages/jax/_src/lib/xla_bridge.py:429: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
I0401 12:14:07.186703 23214846677760 xla_bridge.py:260] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
I0401 12:14:07.201218 23214846677760 xla_bridge.py:260] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
I0401 12:14:08.051018 23210343319296 utils.py:299] [jaxline] experiment init starting...
I0401 12:14:12.121410 23210343319296 utils.py:306] [jaxline] experiment init finished.
I0401 12:14:13.901458 23210343319296 utils.py:299] [jaxline] training loop starting...
I0401 12:14:13.902453 23214846677760 train.py:216] Checkpoint None invalid or already evaluated, waiting.
I0401 12:14:13.909003 23201798149888 dataset_info.py:565] Load dataset info from /cluster/home/tensorflow_datasets/cifar10/3.0.2
I0401 12:14:13.912754 23201798149888 dataset_builder.py:522] Reusing dataset cifar10 (/cluster/home/tensorflow_datasets/cifar10/3.0.2)
I0401 12:14:13.958257 23201798149888 logging_logger.py:49] Constructing tf.data.Dataset cifar10 for split train, from /cluster/home/tensorflow_datasets/cifar10/3.0.2
I0401 12:14:19.462527 23210343319296 experiment.py:299] Initialized parameters randomly rather than restoring from checkpoint.
I0401 12:14:23.920788 23214846677760 train.py:216] Checkpoint None invalid or already evaluated, waiting.
/cluster/home/image_venv310/lib/python3.10/site-packages/jax/_src/tree_util.py:185: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.
  warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '

/cluster/home/image_venv310/lib/python3.10/site-packages/jax/_src/lib/xla_bridge.py:429: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
I0401 12:15:13.988235 23214846677760 train.py:216] Checkpoint None invalid or already evaluated, waiting.
I0401 12:15:23.998786 23214846677760 train.py:216] Checkpoint None invalid or already evaluated, waiting.
I0401 12:15:24.926065 22881994974976 train.py:38] global_step: 100, {'acc1': 20.703125, 'acc5': 76.953125, 'batch_size': 4096, 'data_seen': 25344, 'grad_norms_before_clipping_max': 21.63996696472168, 'grad_norms_before_clipping_mean': 10.72304916381836, 'grad_norms_before_clipping_median': 9.968271255493164, 'grad_norms_before_clipping_min': 6.56785249710083, 'grad_norms_before_clipping_std': 2.746678590774536, 'grads_clipped': 1.0, 'grads_norm': 1.1377094984054565, 'l2_loss': 3267.833251953125, 'learning_rate': 2.0, 'noise_std': 0.0006835937383584678, 'reg': 0.0, 'snr_global': 0.07086554169654846, 'steps_per_sec': 1.4079767981167162, 'train_loss': 2.0689990520477295, 'train_loss_max': 3.550971746444702, 'train_loss_mean': 2.0689990520477295, 'train_loss_median': 2.105243682861328, 'train_loss_min': 0.9171600341796875, 'train_loss_std': 0.417292982339859, 'train_obj': 2.0689990520477295, 'update_every': 16, 'update_step': 6}
I0401 12:15:34.009112 23214846677760 train.py:216] Checkpoint None invalid or already evaluated, waiting.
I0401 12:15:44.019091 23214846677760 train.py:216] Checkpoint None invalid or already evaluated, waiting.
I0401 12:15:44.721147 22881994974976 train.py:38] global_step: 200, {'acc1': 31.640625, 'acc5': 82.03125, 'batch_size': 4096, 'data_seen': 50944, 'grad_norms_before_clipping_max': 16.47217559814453, 'grad_norms_before_clipping_mean': 9.767989158630371, 'grad_norms_before_clipping_median': 9.091947555541992, 'grad_norms_before_clipping_min': 7.483264923095703, 'grad_norms_before_clipping_std': 1.6712777614593506, 'grads_clipped': 1.0, 'grads_norm': 1.1372840404510498, 'l2_loss': 3269.246826171875, 'learning_rate': 2.0, 'noise_std': 0.0006835937383584678, 'reg': 0.0, 'snr_global': 0.06725210696458817, 'steps_per_sec': 5.051736012134754, 'train_loss': 2.05218505859375, 'train_loss_max': 2.9698939323425293, 'train_loss_mean': 2.05218505859375, 'train_loss_median': 2.0428333282470703, 'train_loss_min': 1.2016706466674805, 'train_loss_std': 0.2753119170665741, 'train_obj': 2.05218505859375, 'update_every': 16, 'update_step': 12}
I0401 12:15:54.029448 23214846677760 train.py:216] Checkpoint None invalid or already evaluated, waiting.
I0401 12:15:54.625996 23210343319296 utils.py:578] Saved checkpoint latest with id 0.
I0401 12:16:04.039178 23214846677760 utils.py:590] Returned checkpoint latest with id 0.
I0401 12:16:04.291793 23214846677760 dataset_info.py:565] Load dataset info from /cluster/home/tensorflow_datasets/cifar10/3.0.2
I0401 12:16:04.316308 23214846677760 dataset_builder.py:522] Reusing dataset cifar10 (/cluster/home/tensorflow_datasets/cifar10/3.0.2)
I0401 12:16:04.528922 23214846677760 logging_logger.py:49] Constructing tf.data.Dataset cifar10 for split test, from /cluster/home/tensorflow_datasets/cifar10/3.0.2
I0401 12:16:04.748635 22881994974976 train.py:38] global_step: 300, {'acc1': 26.953125, 'acc5': 67.96875, 'batch_size': 4096, 'data_seen': 76544, 'grad_norms_before_clipping_max': 31.99955940246582, 'grad_norms_before_clipping_mean': 14.536386489868164, 'grad_norms_before_clipping_median': 13.962177276611328, 'grad_norms_before_clipping_min': 6.895867824554443, 'grad_norms_before_clipping_std': 4.881536960601807, 'grads_clipped': 1.0, 'grads_norm': 1.1594922542572021, 'l2_loss': 3270.818603515625, 'learning_rate': 2.0, 'noise_std': 0.0006835937383584678, 'reg': 0.0, 'snr_global': 0.20909515023231506, 'steps_per_sec': 4.9934275826661905, 'train_loss': 2.0765316486358643, 'train_loss_max': 5.022525787353516, 'train_loss_mean': 2.0765316486358643, 'train_loss_median': 1.964513897895813, 'train_loss_min': 0.4019668400287628, 'train_loss_std': 0.9520715475082397, 'train_obj': 2.0765316486358643, 'update_every': 16, 'update_step': 18}
/cluster/home/image_venv310/lib/python3.10/site-packages/jax/_src/lib/xla_bridge.py:429: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
  warnings.warn(
I0401 12:16:19.459383 23214846677760 experiment.py:346] {'acc1_ema': 26.82000732421875, 'acc1_last': 26.46001434326172, 'acc1_polyak': 36.630035400390625, 'acc5_ema': 76.58000183105469, 'acc5_last': 76.07000732421875, 'acc5_polyak': 86.07999420166016, 'dp_epsilon': 7.503825433730629, 'loss_ema': 1.9822626113891602, 'loss_last': 1.993444800376892, 'loss_polyak': 1.812673568725586, 'num_samples': 10000, 'update_step': 18}
I0401 12:16:19.495003 23214846677760 train.py:216] Checkpoint 0 invalid or already evaluated, waiting.
I0401 12:16:26.012063 22881994974976 train.py:38] global_step: 400, {'acc1': 30.46875, 'acc5': 81.25, 'batch_size': 4096, 'data_seen': 102144, 'grad_norms_before_clipping_max': 27.274269104003906, 'grad_norms_before_clipping_mean': 13.5658540725708, 'grad_norms_before_clipping_median': 12.410638809204102, 'grad_norms_before_clipping_min': 6.213404655456543, 'grad_norms_before_clipping_std': 3.9092812538146973, 'grads_clipped': 1.0, 'grads_norm': 1.1419119834899902, 'l2_loss': 3272.21630859375, 'learning_rate': 2.0, 'noise_std': 0.0006835937383584678, 'reg': 0.0, 'snr_global': 0.11373414099216461, 'steps_per_sec': 4.702659360694347, 'train_loss': 1.863084316253662, 'train_loss_max': 5.417091369628906, 'train_loss_mean': 1.8630844354629517, 'train_loss_median': 1.7905535697937012, 'train_loss_min': 0.32042211294174194, 'train_loss_std': 0.9155722856521606, 'train_obj': 1.863084316253662, 'update_every': 16, 'update_step': 24}
I0401 12:16:29.505634 23214846677760 train.py:216] Checkpoint 0 invalid or already evaluated, waiting.
I0401 12:16:39.515980 23214846677760 train.py:216] Checkpoint 0 invalid or already evaluated, waiting.
I0401 12:16:45.776562 23210343319296 utils.py:578] Saved checkpoint latest with id 1.
I0401 12:16:45.776842 22881994974976 train.py:38] global_step: 500, {'acc1': 38.28125, 'acc5': 85.15625, 'batch_size': 4096, 'data_seen': 127744, 'grad_norms_before_clipping_max': 24.741037368774414, 'grad_norms_before_clipping_mean': 13.700193405151367, 'grad_norms_before_clipping_median': 13.279705047607422, 'grad_norms_before_clipping_min': 8.776347160339355, 'grad_norms_before_clipping_std': 3.278830051422119, 'grads_clipped': 1.0, 'grads_norm': 1.1389127969741821, 'l2_loss': 3273.900634765625, 'learning_rate': 2.0, 'noise_std': 0.0006835937383584678, 'reg': 0.0, 'snr_global': 0.09047579765319824, 'steps_per_sec': 5.059544740082135, 'train_loss': 1.7541322708129883, 'train_loss_max': 4.324624538421631, 'train_loss_mean': 1.7541322708129883, 'train_loss_median': 1.6827952861785889, 'train_loss_min': 0.22405478358268738, 'train_loss_std': 0.7736775279045105, 'train_obj': 1.7541322708129883, 'update_every': 16, 'update_step': 31}
I0401 12:16:49.529292 23214846677760 utils.py:590] Returned checkpoint latest with id 1.
I0401 12:16:49.842677 23214846677760 dataset_info.py:565] Load dataset info from /cluster/home/tensorflow_datasets/cifar10/3.0.2
I0401 12:16:49.846248 23214846677760 dataset_builder.py:522] Reusing dataset cifar10 (/cluster/home/tensorflow_datasets/cifar10/3.0.2)
I0401 12:16:49.883034 23214846677760 logging_logger.py:49] Constructing tf.data.Dataset cifar10 for split test, from /cluster/home/tensorflow_datasets/cifar10/3.0.2
I0401 12:17:03.920780 23214846677760 experiment.py:346] {'acc1_ema': 43.09000778198242, 'acc1_last': 42.329994201660156, 'acc1_polyak': 43.16001510620117, 'acc5_ema': 89.19002532958984, 'acc5_last': 88.89002227783203, 'acc5_polyak': 89.4800033569336, 'dp_epsilon': 8.977914832429319, 'loss_ema': 1.6547045707702637, 'loss_last': 1.6644262075424194, 'loss_polyak': 1.6637684106826782, 'num_samples': 10000, 'update_step': 32}
I0401 12:17:04.116245 23214846677760 train.py:216] Checkpoint 1 invalid or already evaluated, waiting.
I0401 12:17:06.270281 22881994974976 train.py:38] global_step: 600, {'acc1': 33.59375, 'acc5': 80.859375, 'batch_size': 4096, 'data_seen': 153344, 'grad_norms_before_clipping_max': 34.96406555175781, 'grad_norms_before_clipping_mean': 16.682771682739258, 'grad_norms_before_clipping_median': 16.687217712402344, 'grad_norms_before_clipping_min': 2.4778926372528076, 'grad_norms_before_clipping_std': 7.029008388519287, 'grads_clipped': 1.0, 'grads_norm': 1.1422324180603027, 'l2_loss': 3275.451904296875, 'learning_rate': 2.0, 'noise_std': 0.0006835937383584678, 'reg': 0.0, 'snr_global': 0.12483368813991547, 'steps_per_sec': 4.87956832468756, 'train_loss': 1.9203698635101318, 'train_loss_max': 5.310922145843506, 'train_loss_mean': 1.9203698635101318, 'train_loss_median': 1.692842721939087, 'train_loss_min': 0.07809478044509888, 'train_loss_std': 1.1521713733673096, 'train_obj': 1.9203698635101318, 'update_every': 16, 'update_step': 37}
I0401 12:17:14.126081 23214846677760 train.py:216] Checkpoint 1 invalid or already evaluated, waiting.
I0401 12:17:24.136556 23214846677760 train.py:216] Checkpoint 1 invalid or already evaluated, waiting.
I0401 12:17:26.028420 22881994974976 train.py:38] global_step: 700, {'acc1': 41.015625, 'acc5': 89.0625, 'batch_size': 4096, 'data_seen': 178944, 'grad_norms_before_clipping_max': 45.206146240234375, 'grad_norms_before_clipping_mean': 16.352386474609375, 'grad_norms_before_clipping_median': 15.187183380126953, 'grad_norms_before_clipping_min': 2.9148075580596924, 'grad_norms_before_clipping_std': 7.583418846130371, 'grads_clipped': 1.0, 'grads_norm': 1.1369668245315552, 'l2_loss': 3276.95556640625, 'learning_rate': 2.0, 'noise_std': 0.0006835937383584678, 'reg': 0.0, 'snr_global': 0.07157699763774872, 'steps_per_sec': 5.06121985020044, 'train_loss': 1.5519847869873047, 'train_loss_max': 4.772342205047607, 'train_loss_mean': 1.5519846677780151, 'train_loss_median': 1.3451415300369263, 'train_loss_min': 0.0608828142285347, 'train_loss_std': 0.9366973638534546, 'train_obj': 1.5519847869873047, 'update_every': 16, 'update_step': 43}
I0401 12:17:29.953248 23210343319296 utils.py:306] [jaxline] training loop finished.
I0401 12:17:29.955536 23210343319296 utils.py:299] [jaxline] final checkpoint starting...
I0401 12:17:29.955831 23210343319296 utils.py:578] Saved checkpoint latest with id 2.
I0401 12:17:29.955891 23210343319296 utils.py:306] [jaxline] final checkpoint finished.
I0401 12:17:29.955952 23210343319296 utils.py:299] [jaxline] rendezvous starting...
I0401 12:17:30.093008 23210343319296 utils.py:306] [jaxline] rendezvous finished.
I0401 12:17:34.149858 23214846677760 utils.py:590] Returned checkpoint latest with id 2.
I0401 12:17:34.301443 23214846677760 dataset_info.py:565] Load dataset info from /cluster/home/tensorflow_datasets/cifar10/3.0.2
I0401 12:17:34.304268 23214846677760 dataset_builder.py:522] Reusing dataset cifar10 (/cluster/home/tensorflow_datasets/cifar10/3.0.2)
I0401 12:17:34.473762 23214846677760 logging_logger.py:49] Constructing tf.data.Dataset cifar10 for split test, from /cluster/home/tensorflow_datasets/cifar10/3.0.2
I0401 12:17:35.382286 23214846677760 experiment.py:346] {'acc1_ema': 46.16999435424805, 'acc1_last': 37.900001525878906, 'acc1_polyak': 46.110015869140625, 'acc5_ema': 91.530029296875, 'acc5_last': 83.7099838256836, 'acc5_polyak': 91.4000015258789, 'dp_epsilon': 10.106488607032036, 'loss_ema': 1.4954969882965088, 'loss_last': 1.7757163047790527, 'loss_polyak': 1.5072647333145142, 'num_samples': 10000, 'update_step': 45}
I0401 12:17:35.389587 23214846677760 train.py:251] Last checkpoint (iteration 720) evaluated, exiting.
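Incidentally, the `update_every: 16` field in the log lines above is consistent with gradient accumulation on this hardware. A minimal sketch of the arithmetic, assuming four devices as the reporter states (all other numbers come directly from the config):

```python
# Sketch (not code from the repo): checking the batching arithmetic
# against the log above.
num_devices = 4                    # reporter's hardware (assumption)
per_device_per_step = 64           # training.batch_size.per_device_per_step
logical_batch_size = 4096          # training.batch_size.init_value
num_train_samples = 50_000         # CIFAR-10 train split
num_updates = 875                  # config num_updates

samples_per_step = num_devices * per_device_per_step   # 256 per accumulation step
update_every = logical_batch_size // samples_per_step  # accumulation steps per update
effective_epochs = num_updates * logical_batch_size / num_train_samples

print(update_every)      # 16, matching 'update_every: 16' in the log
print(effective_epochs)  # 71.68 effective passes over the training data
```

So each logical batch of 4096 is accumulated over 16 forward/backward steps of 256 samples, and the full run of 875 updates corresponds to roughly 72 passes over CIFAR-10.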
ahasanpour commented 1 year ago

I think it's solved; you can find the log here!

I changed std_relative to 3.0, as stated in the paper!
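For reference, the only change needed relative to the config above appears to be the noise multiplier. With `std_relative: 0.7`, the privacy budget is spent almost immediately, which is why the run above hit `dp_epsilon` > 10 and exited after only ~45 updates (note the config also stops at ε = 10 rather than the ε = 8 quoted from Table 3):

```yaml
training:
  dp:
    noise:
      std_relative: 3.0  # was 0.7; Table 3's "0.7" is the accuracy std. dev., not the noise multiplier
```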

lberrada commented 1 year ago

Yes indeed: the stdev mentioned in Table 3 is the standard deviation of the accuracy across runs, not the noise multiplier. Great that you were able to reproduce the results! I will close this.
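To make that concrete, the "± 0.7" in Table 3 summarizes run-to-run variation of the final accuracy. A tiny sketch with made-up accuracies for five hypothetical training runs (illustrative values only, not real results):

```python
import statistics

# Hypothetical final test accuracies from five independent runs
# (made-up values chosen only to illustrate the reporting convention).
accuracies = [78.6, 80.4, 79.5, 79.0, 80.0]

mean = statistics.mean(accuracies)
std = statistics.stdev(accuracies)  # sample standard deviation across runs
print(f"{mean:.1f} +/- {std:.1f}")  # prints "79.5 +/- 0.7"
```

This is the "mean ± std" form used in the paper's tables; it says nothing about the DP noise multiplier, which is a separate hyperparameter (`std_relative` in the config).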