google-deepmind / jax_privacy

Algorithms for Privacy-Preserving Machine Learning in JAX
Apache License 2.0

Checkpoint error in training WRN-16-4 on CIFAR-10 #4

Closed: CHAOS-Yang closed this issue 2 years ago

CHAOS-Yang commented 2 years ago

I want to reproduce the experiment of training WRN-16-4 on CIFAR-10 following the README in experiments/image_classification. I successfully configured the environment and got the code running, using the command from the "Training from Scratch on CIFAR-10" section of the README. However, a checkpoint error appears after a few thousand steps. The console repeatedly prints:

Checkpoint 6 invalid or already evaluated, waiting.

The message repeats every 10 s and training stops. I traced this message to jaxline, and I verified that the memory limit was not exceeded, so I am puzzled.

I ran the experiment on a server with four NVIDIA RTX 3090 GPUs (24 GB each) and 256 GB of RAM.

Thanks in advance!

lberrada commented 2 years ago

Hi, this is not an actual error, but just a log from the evaluation code, stating that it is waiting for a new checkpoint to be produced by the training code so that it can evaluate it.

In more detail:
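The waiting behaviour described in the reply above can be illustrated with a small polling loop. This is a hypothetical sketch, not jaxline's actual evaluation code: the function and parameter names are made up, and the 10-second default only mirrors the interval reported in the issue.

```python
import time
from typing import Callable, Optional


def wait_for_new_checkpoint(
    get_latest_id: Callable[[], Optional[int]],
    last_evaluated: Optional[int],
    poll_interval_s: float = 10.0,
    max_polls: Optional[int] = None,
    sleep: Callable[[float], None] = time.sleep,
) -> Optional[int]:
    """Hypothetical evaluator loop: block until a newer checkpoint exists.

    Returns the id of the first checkpoint newer than `last_evaluated`, or
    None if `max_polls` polls pass without one (None = poll forever).
    """
    polls = 0
    while max_polls is None or polls < max_polls:
        latest = get_latest_id()
        if latest is not None and (last_evaluated is None or latest > last_evaluated):
            return latest
        # Nothing new yet: a status message, not a failure.
        print(f"Checkpoint {latest} invalid or already evaluated, waiting.")
        sleep(poll_interval_s)
        polls += 1
    return None
```

If the training side stops writing checkpoints, a loop like this (with `max_polls=None`) never returns and keeps printing the same message, which matches the symptom reported in this thread.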

CHAOS-Yang commented 2 years ago

Thank you very much for your reply. I confirmed that the process is still running: the training part has stopped, but the evaluation part is still running. The waiting loop went on for almost 6 hours without any new checkpoint being generated. Have you faced a similar problem, or do you have any suggestions for a solution?

lberrada commented 2 years ago

We have not encountered this issue. Do you have some logs to share so that I can see if there's any useful information there?

CHAOS-Yang commented 2 years ago

For example, it printed:

I0628 19:56:09.069872 140008971294464 train.py:38] global_step: 1000, {'acc1': 40.234375, 'acc5': 86.328125, 'batch_size': 4096, 'data_seen': 255744, 'grad_norms_before_clipping_max': 18.18052101135254, 'grad_norms_before_clipping_mean': 10.37612247467041, 'grad_norms_before_clipping_median': 10.249549865722656, 'grad_norms_before_clipping_min': 3.8731744289398193, 'grad_norms_before_clipping_std': 3.1992735862731934, 'grads_clipped': 1.0, 'grads_norm': 4.862388610839844, 'l2_loss': 4011.541748046875, 'learning_rate': 4.0, 'noise_std': 0.0029296875, 'reg': 0.0, 'snr_global': 0.015546687878668308, 'steps_per_sec': 1.6854322570659348, 'train_loss': 1.7164676189422607, 'train_loss_max': 4.354676723480225, 'train_loss_mean': 1.7164676189422607, 'train_loss_median': 1.595414638519287, 'train_loss_min': 0.06957746297121048, 'train_loss_std': 0.778220534324646, 'train_obj': 1.7164676189422607, 'update_every': 16, 'update_step': 62}

After that, there was no further output for 30 minutes, although the program was still running. We checked the GPU and CPU: the program still holds host memory and GPU memory, but GPU-Util drops to zero.

CHAOS-Yang commented 2 years ago

I think I found the problem: TensorFlow reports that no GPUs are available, so there may be a problem with the CUDA installation. Thank you very much for your help so far.
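A check like the one described above can be run from Python to compare what each framework sees. This is a minimal sketch, not part of jax_privacy; the helper name `visible_gpus` is hypothetical, and it assumes TensorFlow 2.x (`tf.config.list_physical_devices`) and a recent JAX (`jax.devices()`).

```python
from typing import Dict, List, Optional


def visible_gpus() -> Dict[str, Optional[List[str]]]:
    """Lists the GPUs that TensorFlow and JAX can each see.

    A framework that is not installed is reported as None; an empty list
    means the framework is installed but found no GPU.
    """
    report: Dict[str, Optional[List[str]]] = {"tensorflow": None, "jax": None}
    try:
        import tensorflow as tf
        report["tensorflow"] = [d.name for d in tf.config.list_physical_devices("GPU")]
    except ImportError:
        pass
    try:
        import jax
        # jax.devices() lists the default backend's devices (CPU if no GPU backend).
        report["jax"] = [str(d) for d in jax.devices() if d.platform == "gpu"]
    except ImportError:
        pass
    return report


if __name__ == "__main__":
    print(visible_gpus())
```

If TensorFlow returns an empty list while `nvidia-smi` lists the GPUs, one common cause is that TensorFlow cannot find the CUDA/cuDNN libraries it was built against.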

CHAOS-Yang commented 2 years ago

May I ask which CUDA version the code runs on, and which cuDNN and cudatoolkit versions go with it?

lberrada commented 2 years ago

We have been able (at some point) to run our open-sourced code with CUDA 11.0 and CuDNN 11.1.

At that time, we also noted that setting the environment variable TF_FORCE_GPU_ALLOW_GROWTH, e.g. TF_FORCE_GPU_ALLOW_GROWTH='true' python run_experiment.py ..., can help: it stops TensorFlow from reserving all GPU memory and leaving none for JAX.
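The same variable can also be set from inside Python. This is a minimal sketch assuming it runs at the very top of the launch script, before TensorFlow is imported; the helper name `enable_tf_gpu_memory_growth` is hypothetical and not part of jax_privacy.

```python
import os
import sys


def enable_tf_gpu_memory_growth() -> None:
    """Ask TensorFlow to allocate GPU memory on demand instead of all at once.

    TensorFlow reads TF_FORCE_GPU_ALLOW_GROWTH when it initializes its GPUs,
    so setting it before importing TensorFlow is the safest choice.
    """
    if "tensorflow" in sys.modules:
        # TensorFlow is already loaded; the setting may come too late.
        print("warning: tensorflow already imported; setting may have no effect")
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"


enable_tf_gpu_memory_growth()
# Import TensorFlow (and the rest of the training pipeline) only after this point.
```

Setting the variable in code rather than on the command line keeps the behaviour tied to the script, at the cost of being easy to miss if TensorFlow is imported earlier by another module.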

In case that's useful, there are more details about installing JAX for GPUs in this JAX ReadMe section.

CHAOS-Yang commented 2 years ago

Thank you very much for your reply; it helped me a lot. Following the tips, I adjusted the code and was able to reproduce the experiments successfully. I added the following code to run_experiment.py:

    from tensorflow.compat.v1 import ConfigProto
    from tensorflow.compat.v1 import InteractiveSession

    # Let TensorFlow allocate GPU memory on demand instead of reserving it all.
    config = ConfigProto()
    config.gpu_options.allow_growth = True
    session = InteractiveSession(config=config)

And it works. Thank you very much.
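On TensorFlow 2.x, the same allow-growth setting can also be applied without a compat.v1 session. The sketch below is not from jax_privacy or from this thread; the helper name `enable_tf_memory_growth` is hypothetical, and it relies on `tf.config.experimental.set_memory_growth`, which must be called before TensorFlow first initializes the GPUs (for example, near the top of run_experiment.py).

```python
from typing import Optional


def enable_tf_memory_growth() -> Optional[int]:
    """TF2-style equivalent of gpu_options.allow_growth = True.

    Returns the number of GPUs configured, or None if TensorFlow is not
    installed. If the GPUs were already initialized, set_memory_growth
    raises a RuntimeError.
    """
    try:
        import tensorflow as tf
    except ImportError:
        return None
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        # Grow this GPU's allocation on demand instead of reserving it all.
        tf.config.experimental.set_memory_growth(gpu, True)
    return len(gpus)
```

Unlike the InteractiveSession approach, this does not create a global TF1-style session; it only changes how TensorFlow allocates GPU memory.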