Hi, sorry about the slow reply.
Here are some answers that will hopefully clarify things:
- --jaxline_mode=train_eval_multithreaded is required to run evaluation; otherwise only training is applied. This is an artefact of the JAXline framework that is currently used.
- global_step is a counter that is incremented at every step, whether that step computes a gradient-accumulation step for the virtual batch size or actually performs an update. The name is unfortunate, but this is also something that JAXline enforces and that we cannot easily modify.
- update_step is indeed incremented whenever a model update is performed.
- With log_train_data_interval=1, you should see the logs at every global_step, and the pattern should become clearer.
- As for dp_epsilon: we actually pre-compute values and cache them on a regular grid of steps. This avoids running slow accounting operations on CPU at every training step, which would slow things down. So dp_epsilon is a piecewise-constant approximation as a function of the step, where each cached value within a constant piece should be an upper bound on the "exact" dp_epsilon.

We're aware that global_step vs update_step is confusing and that the eval setup is not ideal; these issues stem from constraints of JAXline. The good news is that the confusing bits mentioned so far (eval being difficult to run, global step vs update step) will go away in a forthcoming version that I hope to release soon!
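To make the counter behaviour concrete, here is a minimal sketch of how a global step counter and an update step counter can evolve under gradient accumulation for a virtual batch size. This is illustrative only, not the actual JAXline or repository code; the single-device setting and the variable names (virtual_batch_size, per_device_per_step, accumulation_steps) are assumptions for the example:

```python
# Illustrative sketch only -- not the actual JAXline / repository code.
# A framework can advance "global_step" on every gradient-accumulation step
# while only advancing "update_step" when the full virtual batch has been
# accumulated and a model update is actually applied.

virtual_batch_size = 4096        # assumed target (virtual) batch size
per_device_per_step = 64         # assumed micro-batch size per device per step
num_devices = 1                  # assumption: single accelerator
accumulation_steps = virtual_batch_size // (per_device_per_step * num_devices)

global_step = 0   # incremented on every step, including accumulation-only steps
update_step = 0   # incremented only when a model update is performed
data_seen = 0     # number of training examples consumed so far

for _ in range(3 * accumulation_steps):  # run a few virtual batches
    global_step += 1
    data_seen += per_device_per_step * num_devices
    if global_step % accumulation_steps == 0:
        # All micro-batches of the virtual batch have been accumulated:
        # the (clipped, noised) gradient update would be applied here.
        update_step += 1

print(global_step, update_step, data_seen)
# With the numbers above: 192 global steps, 3 updates, 12288 examples seen.
```

With these (assumed) numbers, data_seen grows by per_device_per_step examples per global step, while update_step advances only once every accumulation_steps global steps.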
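The dp_epsilon caching can be pictured in the same spirit. The sketch below is not the library's actual accountant; exact_epsilon and the 100-step grid are placeholders. The only point is that values are pre-computed once on a coarse grid, and the reported epsilon stays constant within each piece while upper-bounding the exact value:

```python
# Illustrative sketch of the piecewise-constant caching idea -- not the
# library's actual privacy accountant.
import bisect

def exact_epsilon(step: int) -> float:
    """Placeholder for an expensive, monotonically increasing accountant."""
    return 0.01 * step ** 0.5

grid = list(range(0, 1001, 100))             # assumption: cache every 100 steps
cache = {s: exact_epsilon(s) for s in grid}  # slow accounting done once, up front

def cached_epsilon(step: int) -> float:
    # Round *up* to the next grid point so the reported value upper-bounds the
    # exact epsilon within the piece (steps past the grid are clamped).
    idx = bisect.bisect_left(grid, step)
    return cache[grid[min(idx, len(grid) - 1)]]

print(cached_epsilon(7), cached_epsilon(250))  # constant within each 100-step piece
```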
Hello,
Thanks for sharing the information; it really helped clarify things. Hearing about the improvements in the latest version is great news! Do you have an estimated release timeframe for the new version?
In the meantime, however, using multi-threaded mode remains essential for running training and evaluation together, so I would like to explore potential solutions for the error mentioned earlier. Could it be caused by a library incompatibility? Given the intricate dependencies among the CUDA toolkit, cuDNN, JAX, jaxlib, TensorFlow, and other components, it is plausible that a version mismatch is behind this error. Could you kindly share the specific versions of these libraries used in your experiments? Alternatively, is there another potential cause for this error?
Thanks once again for your assistance.
The new version is now available with #20; hopefully this will side-step the train-eval issue.
Hello,
Thank you! This has resolved many of my issues, and the code seems to be running smoothly.
I do have an additional question regarding the newly added dataset, CheXpert, in the latest version. I can see the CheXpert dataset and dataloader configurations in the CheXpert folder, but I cannot find a configuration file similar to the one provided for CIFAR-10 in the configs folder. I'm looking for a file that lists the base hyper-parameters and configuration values for CheXpert, so that I can replicate the results reported in the Unlocking Accuracy and Fairness in Differentially Private Image Classification paper.
Once again, I appreciate your help. Thank you!
Hi, great to hear that the new version is running smoothly. The hyper-parameters to reproduce our results are detailed in appendix section C.5 of the paper. We do not provide a config for every single experiment because that is not manageable from a maintenance point of view, but it should be fairly easy to adapt existing configs for different experiments.
Note that for CheXpert, you will also need to manage the dataset downloading and local loading on your end, because it requires a special license (all details should be on their website https://stanfordaimi.azurewebsites.net/datasets/8cbd9ed4-2eb9-4565-affc-111cf4f7ebe2).
Hello, yes, the hyperparameters outlined in the paper appear to cover everything. I appreciate you sharing the information about how to get access to CheXpert. Thank you so much for your help; I will close this issue.
Hello there!
I'm currently working on the Image Classification experiments.
When trying to run the code for CIFAR fine-tuning or training from scratch on CIFAR-10, using --jaxline_mode=train_eval_multithreaded triggers this error:
However, removing this mode seems to prevent the code from going through the evaluation step, or at least, the logs don't capture it. All the logs seem to pertain to training, following the 'train/parameter' format as illustrated below.
On another note, I'm curious about the batch size. The default hyperparameters suggest a training batch size of 4096 and a per_device_per_step of 64. However, according to the logs, the data_seen at each step is 6400, which seems to align with global steps × per_device_per_step (100 × 64) but not with the original batch size. Another confusing factor is train/update_step, which appears to lack a specific order (it goes from 1 to 3, 4, 6, 7, ...). Understanding the relationship between the batch size (4096), per_device_per_step (64), the number of global steps (100), update_step, and epochs has become a bit of a puzzle.
Additionally, I've noticed that the dp_epsilon value changes only after 7 log prints. I presume this behavior originates from taking virtual steps and working with a virtual batch size. However, making sense of this and the various numbers mentioned earlier is proving to be quite a challenge. Any insights or clarifications would be highly appreciated!
Thanks!