google-deepmind / jax_privacy

Algorithms for Privacy-Preserving Machine Learning in JAX
Apache License 2.0

Multithread and Batch Size Issue #19

Closed: SabrinaMokhtari closed this issue 10 months ago.

SabrinaMokhtari commented 11 months ago

Hello there!

I'm currently working on the Image Classification experiments.

When I run the code for CIFAR fine-tuning, or for training from scratch on CIFAR-10, with --jaxline_mode=train_eval_multithreaded, I get this error:

I1206 22:11:17.449106 140003988469312 utils.py:616] Saved checkpoint latest with id 0.
I1206 22:11:25.352441 140009254860608 utils.py:628] Returned checkpoint latest with id 0.
Traceback (most recent call last):
  File "/jax/jax_privacy/jax_privacy/experiments/image_classification/run_experiment.py", line 32, in <module>
    app.run(functools.partial(platform.main, experiment.Experiment))
  File "/.conda/envs/tf/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/.conda/envs/tf/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/.conda/envs/tf/lib/python3.10/site-packages/jaxline/utils.py", line 522, in inner_wrapper
    return f(*args, **kwargs)
  File "/.conda/envs/tf/lib/python3.10/site-packages/jaxline/platform.py", line 148, in main
    train.evaluate(experiment_class, config,
  File "/.conda/envs/tf/lib/python3.10/site-packages/jaxline/utils.py", line 658, in inner_wrapper
    return fn(*args, **kwargs)
  File "/.conda/envs/tf/lib/python3.10/site-packages/jaxline/train.py", line 240, in evaluate
    scalar_values = utils.evaluate_should_return_dict(experiment.evaluate)(
  File "/.conda/envs/tf/lib/python3.10/site-packages/jaxline/utils.py", line 559, in evaluate_with_warning
    evaluate_out = f(*args, **kwargs)
  File "/jax/jax_privacy/jax_privacy/src/training/experiment.py", line 335, in evaluate
    self._step_count = self.updater.step_count_from_opt_state(
  File "/jax/jax_privacy/jax_privacy/src/training/dp_updater.py", line 168, in step_count_from_opt_state
    update_step=int(jaxline_utils.get_first(opt_state.gradient_step)),
  File "/.conda/envs/tf/lib/python3.10/site-packages/jax/_src/array.py", line 276, in __int__
    return self._value.__int__()
  File "/.conda/envs/tf/lib/python3.10/site-packages/jax/_src/profiler.py", line 334, in wrapper
    return func(*args, **kwargs)
  File "/.conda/envs/tf/lib/python3.10/site-packages/jax/_src/array.py", line 581, in _value
    self._npy_value = self._single_device_array_to_np_array()  # type: ignore
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: CopyToHostAsync() called on deleted or donated buffer
E1206 22:11:28.653215 140003988469312 utils.py:310] [jaxline] training loop failed with error.
Traceback (most recent call last):
  File "/.conda/envs/tf/lib/python3.10/site-packages/jaxline/utils.py", line 307, in log_activity
    yield
  File "/.conda/envs/tf/lib/python3.10/site-packages/jaxline/experiment.py", line 211, in train_loop
    action(t, state.global_step, scalar_outputs)  # pytype: disable=wrong-arg-types  # jax-ndarray
  File "/.conda/envs/tf/lib/python3.10/site-packages/jaxline/utils.py", line 507, in __call__
    self._apply_fn_future = self._apply_fn(  # pylint: disable=not-callable
  File "/.conda/envs/tf/lib/python3.10/site-packages/jaxline/utils.py", line 292, in decorator
    return pool.submit(trap_errors, *args, **kwargs)
  File "/.conda/envs/tf/lib/python3.10/concurrent/futures/thread.py", line 167, in submit
    raise RuntimeError('cannot schedule new futures after shutdown')
RuntimeError: cannot schedule new futures after shutdown

However, without this mode the code never seems to reach the evaluation step, or at least the logs don't show it: all the log entries relate to training and use the 'train/<metric>' key format, as illustrated below.

On another note, I'm curious about the batch size. The default hyperparameters set a training batch size of 4096 and a per_device_per_step of 64. However, according to the logs, data_seen grows by 6400 between consecutive log entries, which matches global steps × per_device_per_step (100 × 64) but not the configured batch size. Another confusing factor is train/update_step, which does not advance by a fixed amount between log entries (it goes 1, 3, 4, 6, 7, ...). I'm struggling to understand the relationship between the batch size (4096), per_device_per_step (64), the number of global steps between log entries (100), update_step, and epochs.
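The counters in the logs below can be reproduced with a few lines of arithmetic. The sketch is a simplified model under stated assumptions, not jax_privacy's actual implementation: it assumes a single device, one micro-batch of per_device_per_step examples per global_step, one model update every batch_size / per_device_per_step = 4096 / 64 = 64 micro-batches (matching 'train/update_every': 64 in the logs), and epochs measured against the 50,000 CIFAR-10 training images. The `global_step - 1` offset is inferred from the logs (data_seen = 6336 = 99 × 64 at global_step 100), not from the library source.

```python
# Simplified model of the step counters in the jax_privacy / JAXline logs.
# Assumptions (not taken from the library): a single device, one micro-batch
# of PER_DEVICE_PER_STEP examples per global_step, CIFAR-10 with 50,000
# training images.
BATCH_SIZE = 4096            # virtual batch size ('train/batch_size')
PER_DEVICE_PER_STEP = 64     # examples per device per global_step
NUM_DEVICES = 1              # assumption: single accelerator
NUM_TRAIN_EXAMPLES = 50_000  # CIFAR-10 training set size

# Number of micro-batches whose gradients are accumulated per model update.
UPDATE_EVERY = BATCH_SIZE // (PER_DEVICE_PER_STEP * NUM_DEVICES)


def counters(global_step: int) -> dict:
    """Returns the counters logged at `global_step` under this model.

    The offset by one is inferred from the logs: the logged values seem to
    count the steps completed before the logged one.
    """
    completed = global_step - 1
    data_seen = completed * PER_DEVICE_PER_STEP * NUM_DEVICES
    return {
        "update_every": UPDATE_EVERY,
        # One real update per UPDATE_EVERY accumulation (global) steps.
        "update_step": completed // UPDATE_EVERY,
        "data_seen": data_seen,
        "epochs": data_seen / NUM_TRAIN_EXAMPLES,
    }


for step in range(100, 900, 100):
    c = counters(step)
    print(step, c["update_step"], c["data_seen"], round(c["epochs"], 5))
```

Running it prints update_step values 1, 3, 4, 6, 7, 9, 10, 12 and data_seen values from 6336 to 51136 for global steps 100 to 800, matching the logs. The apparent skipping in update_step is just integer division: 100 global steps cover 100 / 64 ≈ 1.56 updates, so the logged count advances by 1 or 2 between entries.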

I've also noticed that dp_epsilon only changes after 7 log prints. I presume this comes from taking virtual steps with a virtual batch size, but making sense of this, and of the numbers above, is proving quite a challenge. Any insights or clarifications would be highly appreciated!

I1206 21:45:46.449987 140215121851968 train.py:40] global_step: 100, {'steps_per_sec': 2.0826217976325334, 'train/acc1': 14.0625, 'train/acc5': 57.8125, 'train/batch_size': 4096, 'train/data_seen': 6336, 'train/dp_epsilon': 0.08617811428309091, 'train/epochs': 0.12671999633312225, 'train/grad_norms_before_clipping_max': 11.097847938537598, 'train/grad_norms_before_clipping_mean': 7.421786308288574, 'train/grad_norms_before_clipping_median': 7.386024475097656, 'train/grad_norms_before_clipping_min': 6.012624740600586, 'train/grad_norms_before_clipping_std': 0.9031921625137329, 'train/grads_clipped': 1.0, 'train/grads_norm': 4.053260326385498, 'train/l2_loss': 3307.333740234375, 'train/learning_rate': 2.0, 'train/loss': 2.2892117500305176, 'train/loss_max': 2.4759721755981445, 'train/loss_mean': 2.2892119884490967, 'train/loss_median': 2.306485176086426, 'train/loss_min': 2.0098161697387695, 'train/loss_std': 0.09207069128751755, 'train/noise_std': 0.00244140625, 'train/obj': 2.2892117500305176, 'train/reg': 0.0, 'train/snr_global': 0.03148417919874191, 'train/update_every': 64, 'train/update_step': 1}
I1206 21:46:08.849350 140215121851968 train.py:40] global_step: 200, {'steps_per_sec': 4.465434843913596, 'train/acc1': 17.1875, 'train/acc5': 71.875, 'train/batch_size': 4096, 'train/data_seen': 12736, 'train/dp_epsilon': 0.08617811428309091, 'train/epochs': 0.254720002412796, 'train/grad_norms_before_clipping_max': 12.894222259521484, 'train/grad_norms_before_clipping_mean': 7.597695350646973, 'train/grad_norms_before_clipping_median': 7.392450332641602, 'train/grad_norms_before_clipping_min': 6.075026512145996, 'train/grad_norms_before_clipping_std': 1.1005326509475708, 'train/grads_clipped': 1.0, 'train/grads_norm': 4.054327964782715, 'train/l2_loss': 3373.004150390625, 'train/learning_rate': 2.0, 'train/loss': 2.2097840309143066, 'train/loss_max': 2.4616034030914307, 'train/loss_mean': 2.2097842693328857, 'train/loss_median': 2.217123031616211, 'train/loss_min': 1.8365141153335571, 'train/loss_std': 0.12813915312290192, 'train/noise_std': 0.00244140625, 'train/obj': 2.2097840309143066, 'train/reg': 0.0, 'train/snr_global': 0.03480687364935875, 'train/update_every': 64, 'train/update_step': 3}
I1206 21:46:20.105232 140274190288704 utils.py:616] Saved checkpoint latest with id 0.
I1206 21:46:31.318336 140215121851968 train.py:40] global_step: 300, {'steps_per_sec': 4.450265468697612, 'train/acc1': 28.125, 'train/acc5': 70.3125, 'train/batch_size': 4096, 'train/data_seen': 19136, 'train/dp_epsilon': 0.08617811428309091, 'train/epochs': 0.3827199935913086, 'train/grad_norms_before_clipping_max': 12.506253242492676, 'train/grad_norms_before_clipping_mean': 7.85102653503418, 'train/grad_norms_before_clipping_median': 7.735785961151123, 'train/grad_norms_before_clipping_min': 5.919890403747559, 'train/grad_norms_before_clipping_std': 1.08476984500885, 'train/grads_clipped': 1.0, 'train/grads_norm': 4.053022861480713, 'train/l2_loss': 3405.334716796875, 'train/learning_rate': 2.0, 'train/loss': 2.1913857460021973, 'train/loss_max': 2.5004332065582275, 'train/loss_mean': 2.1913857460021973, 'train/loss_median': 2.198826789855957, 'train/loss_min': 1.8431569337844849, 'train/loss_std': 0.14818917214870453, 'train/noise_std': 0.00244140625, 'train/obj': 2.1913857460021973, 'train/reg': 0.0, 'train/snr_global': 0.03667474538087845, 'train/update_every': 64, 'train/update_step': 4}
I1206 21:46:53.617524 140215121851968 train.py:40] global_step: 400, {'steps_per_sec': 4.484427345599816, 'train/acc1': 21.875, 'train/acc5': 82.8125, 'train/batch_size': 4096, 'train/data_seen': 25536, 'train/dp_epsilon': 0.08617811428309091, 'train/epochs': 0.5107200145721436, 'train/grad_norms_before_clipping_max': 17.501323699951172, 'train/grad_norms_before_clipping_mean': 9.197774887084961, 'train/grad_norms_before_clipping_median': 8.795980453491211, 'train/grad_norms_before_clipping_min': 6.563833236694336, 'train/grad_norms_before_clipping_std': 1.8294460773468018, 'train/grads_clipped': 1.0, 'train/grads_norm': 4.054222583770752, 'train/l2_loss': 3470.687744140625, 'train/learning_rate': 2.0, 'train/loss': 2.0719828605651855, 'train/loss_max': 2.994951009750366, 'train/loss_mean': 2.0719833374023438, 'train/loss_median': 2.0422544479370117, 'train/loss_min': 1.4358981847763062, 'train/loss_std': 0.2828764319419861, 'train/noise_std': 0.00244140625, 'train/obj': 2.0719828605651855, 'train/reg': 0.0, 'train/snr_global': 0.03867485374212265, 'train/update_every': 64, 'train/update_step': 6}
I1206 21:47:16.053574 140274190288704 utils.py:616] Saved checkpoint latest with id 1.
I1206 21:47:16.053966 140215121851968 train.py:40] global_step: 500, {'steps_per_sec': 4.453045586507099, 'train/acc1': 34.375, 'train/acc5': 85.9375, 'train/batch_size': 4096, 'train/data_seen': 31936, 'train/dp_epsilon': 0.08617811428309091, 'train/epochs': 0.6387199759483337, 'train/grad_norms_before_clipping_max': 15.863215446472168, 'train/grad_norms_before_clipping_mean': 9.66342830657959, 'train/grad_norms_before_clipping_median': 9.340354919433594, 'train/grad_norms_before_clipping_min': 6.10451602935791, 'train/grad_norms_before_clipping_std': 1.9625134468078613, 'train/grads_clipped': 1.0, 'train/grads_norm': 4.052708625793457, 'train/l2_loss': 3503.772705078125, 'train/learning_rate': 2.0, 'train/loss': 2.0133843421936035, 'train/loss_max': 3.1231155395507812, 'train/loss_mean': 2.0133841037750244, 'train/loss_median': 1.995213508605957, 'train/loss_min': 1.2569082975387573, 'train/loss_std': 0.3511837422847748, 'train/noise_std': 0.00244140625, 'train/obj': 2.0133843421936035, 'train/reg': 0.0, 'train/snr_global': 0.03174315765500069, 'train/update_every': 64, 'train/update_step': 7}
I1206 21:47:38.578334 140215121851968 train.py:40] global_step: 600, {'steps_per_sec': 4.443082685220641, 'train/acc1': 17.1875, 'train/acc5': 59.375, 'train/batch_size': 4096, 'train/data_seen': 38336, 'train/dp_epsilon': 0.08617811428309091, 'train/epochs': 0.7667199969291687, 'train/grad_norms_before_clipping_max': 26.938108444213867, 'train/grad_norms_before_clipping_mean': 15.219337463378906, 'train/grad_norms_before_clipping_median': 15.87930679321289, 'train/grad_norms_before_clipping_min': 9.131158828735352, 'train/grad_norms_before_clipping_std': 4.187152862548828, 'train/grads_clipped': 1.0, 'train/grads_norm': 4.069787502288818, 'train/l2_loss': 3569.447509765625, 'train/learning_rate': 2.0, 'train/loss': 2.7184159755706787, 'train/loss_max': 5.833471298217773, 'train/loss_mean': 2.7184157371520996, 'train/loss_median': 2.735433578491211, 'train/loss_min': 0.6383439302444458, 'train/loss_std': 1.306814193725586, 'train/noise_std': 0.00244140625, 'train/obj': 2.7184159755706787, 'train/reg': 0.0, 'train/snr_global': 0.09964726865291595, 'train/update_every': 64, 'train/update_step': 9}
I1206 21:48:01.223080 140215121851968 train.py:40] global_step: 700, {'steps_per_sec': 4.418504996894261, 'train/acc1': 12.5, 'train/acc5': 65.625, 'train/batch_size': 4096, 'train/data_seen': 44736, 'train/dp_epsilon': 0.08617811428309091, 'train/epochs': 0.8947199583053589, 'train/grad_norms_before_clipping_max': 16.224781036376953, 'train/grad_norms_before_clipping_mean': 11.527853012084961, 'train/grad_norms_before_clipping_median': 11.496959686279297, 'train/grad_norms_before_clipping_min': 7.3519721031188965, 'train/grad_norms_before_clipping_std': 1.7476364374160767, 'train/grads_clipped': 1.0, 'train/grads_norm': 4.0550971031188965, 'train/l2_loss': 3602.263671875, 'train/learning_rate': 2.0, 'train/loss': 2.2895689010620117, 'train/loss_max': 3.676800012588501, 'train/loss_mean': 2.2895689010620117, 'train/loss_median': 1.953697919845581, 'train/loss_min': 1.4988765716552734, 'train/loss_std': 0.6385901570320129, 'train/noise_std': 0.00244140625, 'train/obj': 2.2895689010620117, 'train/reg': 0.0, 'train/snr_global': 0.04279553145170212, 'train/update_every': 64, 'train/update_step': 10}
I1206 21:48:12.545383 140274190288704 utils.py:616] Saved checkpoint latest with id 2.
I1206 21:48:23.718779 140215121851968 train.py:40] global_step: 800, {'steps_per_sec': 4.444768246350751, 'train/acc1': 28.125, 'train/acc5': 82.8125, 'train/batch_size': 4096, 'train/data_seen': 51136, 'train/dp_epsilon': 0.12439736664167292, 'train/epochs': 1.0227199792861938, 'train/grad_norms_before_clipping_max': 13.829648971557617, 'train/grad_norms_before_clipping_mean': 8.596567153930664, 'train/grad_norms_before_clipping_median': 8.216176986694336, 'train/grad_norms_before_clipping_min': 6.665882587432861, 'train/grad_norms_before_clipping_std': 1.405860185623169, 'train/grads_clipped': 1.0, 'train/grads_norm': 4.056445121765137, 'train/l2_loss': 3667.319091796875, 'train/learning_rate': 2.0, 'train/loss': 2.097318172454834, 'train/loss_max': 2.563758611679077, 'train/loss_mean': 2.097318172454834, 'train/loss_median': 2.113304376602173, 'train/loss_min': 1.6178672313690186, 'train/loss_std': 0.19550736248493195, 'train/noise_std': 0.00244140625, 'train/obj': 2.097318172454834, 'train/reg': 0.0, 'train/snr_global': 0.04158717393875122, 'train/update_every': 64, 'train/update_step': 12}

Thanks!

lberrada commented 11 months ago

Hi, sorry about the slow reply.

Here are some answers that will hopefully clarify things:

  1. you're correct that --jaxline_mode=train_eval_multithreaded is required to run evaluation; otherwise only training is run. This is an artefact of the JAXline framework that the code currently uses.
  2. global_step is a counter that gets incremented at each step, whether that step computes a gradient accumulation step for the virtual batch size or actually performs an update. The name is unfortunate, but this is also something that JAXline enforces, and that we cannot easily modify.
  3. update_step is indeed incremented whenever a model update is performed.
  4. the logs are not shown for each step, which is why you observe the strange "skipping" pattern. If you set log_train_data_interval=1 here, you should see the logs at every global_step and the pattern should become clearer.
  5. Another subtlety about dp_epsilon is that we actually pre-compute values and cache them on a regular grid of the steps. This avoids running slow accounting operations on CPU at each training step, which would slow things down. So dp_epsilon is actually a piecewise constant approximation as a function of the steps, where each cached value within a constant piece should be an upper bound on the "exact" dp_epsilon.
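The caching described in point 5 can be illustrated with a small sketch. Everything here is hypothetical: `toy_epsilon` is a stand-in for a real privacy accountant (a toy monotone function, not a DP computation), `CachedEpsilon` is an illustrative helper rather than a jax_privacy class, and the grid spacing is arbitrary. The idea it shows is the lookup: since epsilon never decreases as training proceeds, reporting the cached value at the nearest grid point at or above the current step yields a piecewise-constant upper bound.

```python
import bisect
import math


def toy_epsilon(num_updates: int) -> float:
    """Hypothetical stand-in for a slow privacy accountant.

    Any non-decreasing function of the number of updates works for this
    demo; this is NOT a real differential privacy epsilon computation.
    """
    return 0.05 * math.sqrt(num_updates + 1)


class CachedEpsilon:
    """Pre-computes epsilon on a regular grid of update steps (illustrative)."""

    def __init__(self, accountant, max_steps: int, grid_spacing: int):
        self.max_steps = max_steps
        # Grid points 0, spacing, 2 * spacing, ..., with the last one >= max_steps.
        self.grid = list(range(0, max_steps + grid_spacing, grid_spacing))
        # The expensive accounting runs once per grid point, not once per step.
        self.values = [accountant(s) for s in self.grid]

    def __call__(self, step: int) -> float:
        if not 0 <= step <= self.max_steps:
            raise ValueError(f"step {step} outside [0, {self.max_steps}]")
        # Index of the smallest grid point >= step. Because epsilon is
        # non-decreasing in the step count, this cached value upper-bounds
        # the exact epsilon at `step`.
        i = bisect.bisect_left(self.grid, step)
        return self.values[i]


eps = CachedEpsilon(toy_epsilon, max_steps=100, grid_spacing=8)
for step in [1, 5, 8, 9, 12]:
    print(step, eps(step), toy_epsilon(step))
```

Steps 1 through 8 all report the value cached at grid point 8, and step 9 jumps to the value at 16, which is the "constant for several log prints, then a jump" pattern. In the logs above, dp_epsilon likewise stays at 0.0862 through update_step 10 and changes by update_step 12; the grid spacing jax_privacy actually uses is not stated in this thread.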

We're aware that global_step vs update_step is confusing, and that the eval setup is not ideal; these issues stem from constraints of JAXline. The good news, though, is that the confusing bits mentioned so far (eval being difficult to run, global step vs update step) will go away in a forthcoming version that I hope to release soon!

SabrinaMokhtari commented 11 months ago

Hello,

Thanks for sharing this information; it really helped clarify things. Hearing about the improvements in the upcoming version is great news! Do you have an estimated release timeframe for it?

In the meantime, however, multithreaded mode remains essential for running training and evaluation together, so I'd like to explore possible fixes for the error above. Could it be caused by a library incompatibility? Given the intricate dependencies among the CUDA Toolkit, cuDNN, JAX, jaxlib, TensorFlow, and other components, a version mismatch could plausibly cause this error. Could you share the specific versions of these libraries used in your experiments? Alternatively, is there another potential cause for this error?

Thanks once again for your assistance.

lberrada commented 10 months ago

The new version is now available with #20; hopefully this will sidestep the train-eval issue.

SabrinaMokhtari commented 10 months ago

Hello,

Thank you! This has resolved many of my issues, and the code seems to be running smoothly.

I do have an additional question about CheXpert, the dataset newly added in the latest version. I see that the CheXpert dataset and dataloader configurations are in the CheXpert folder, but I can't find a configuration file like the one you provide for CIFAR-10 in the configs folder. I'm looking for a file with the base hyperparameters and configuration values for CheXpert, so that I can replicate the results in Unlocking Accuracy and Fairness in Differentially Private Image Classification.

Once again, I appreciate your help. Thank you!

lberrada commented 10 months ago

Hi, great to hear that the new version is running smoothly. The hyper-parameters to reproduce our results are detailed in appendix section C.5 of the paper. We do not provide a config for every single experiment because that is not manageable from a maintenance point of view, but it should be fairly easy to adapt existing configs for different experiments.

Note that for CheXpert, you will also need to handle downloading the dataset and loading it locally on your end, because it requires a special license (all details should be on their website https://stanfordaimi.azurewebsites.net/datasets/8cbd9ed4-2eb9-4565-affc-111cf4f7ebe2).

SabrinaMokhtari commented 10 months ago

Hello, yes, the hyperparameters outlined in the paper appear to cover everything. I appreciate you sharing how to get access to CheXpert. Thank you so much for your help. I will close this issue.