google-deepmind / jax_privacy

Algorithms for Privacy-Preserving Machine Learning in JAX
Apache License 2.0

JAX issue: Freezing the python requirements #5

Closed: RoyRin closed this issue 1 year ago

RoyRin commented 2 years ago

Hi, I am trying to reproduce your results, and JAX is using up 100% of a single GPU and freezing on the first training step.

I think part of the issue is that I am using a different version of JAX than jax_privacy expects. Could you please share the exact versions of the Python dependencies? Relatedly, it seems like you are using CUDA 11.0; is that correct?

Note: the freeze happens at this line: https://github.com/deepmind/jax_privacy/blob/main/jax_privacy/src/training/image_classification/updater.py#L213, when it calls _pmapped_update and tries to run jax.pmap.

Additional edit: jax.distributed.initialize is never called in the codebase, which makes me believe that this codebase runs on JAX < 0.3.0, based on this changelog: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-2-25-nov-10-2021

lberrada commented 2 years ago

Hi, we were able to run this code relatively recently with this requirements file: https://github.com/deepmind/jax_privacy/blob/dev/requirements.txt. Could you try using it when installing the dependencies? The main change compared to the requirements on the main branch is that we pin jax==0.3.5 (otherwise there can be issues with jaxlib, which jax itself depends on).
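A quick way to confirm that an environment actually picked up a pin like jax==0.3.5 is to compare installed versions against the pinned ones. The sketch below uses only the standard library; check_pins is a hypothetical helper, and only the jax pin comes from this thread (the matching jaxlib version is not stated here, so it is left out).

```python
# Compare installed package versions against pinned versions.
# check_pins is a hypothetical helper; only jax==0.3.5 comes from the thread.
from importlib.metadata import PackageNotFoundError, version


def check_pins(pins):
    """Return {package: (pinned, installed_or_None)} for every mismatch."""
    mismatches = {}
    for package, pinned in pins.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None  # the package is not installed at all
        if installed != pinned:
            mismatches[package] = (pinned, installed)
    return mismatches


if __name__ == "__main__":
    problems = check_pins({"jax": "0.3.5"})
    for package, (pinned, installed) in problems.items():
        print(f"{package}: pinned {pinned}, installed {installed}")
    if not problems:
        print("All pinned versions match.")
```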

That's correct for CUDA: we used CUDA 11.0 and cuDNN 11.1.

RoyRin commented 2 years ago

I had tried this, but I am trying again now. If possible, would you mind posting the exact output of pip freeze > requirements.txt, so that I can check against the exact same dependencies?
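If running pip itself is inconvenient, a listing close to pip freeze can be produced from Python and diffed between two machines. This is a minimal sketch: freeze and diff are hypothetical helpers, and the output only approximates pip freeze (for example, editable or VCS installs are printed differently by pip).

```python
# Approximate `pip freeze` from Python and diff two such listings.
from importlib.metadata import distributions


def freeze():
    """Return sorted 'name==version' lines for installed distributions."""
    lines = set()
    for dist in distributions():
        name = dist.metadata["Name"]
        if name:  # skip distributions with broken or missing metadata
            lines.add(f"{name}=={dist.version}")
    return sorted(lines, key=str.lower)


def diff(mine, theirs):
    """Return (only_in_mine, only_in_theirs) for two freeze listings."""
    a, b = set(mine), set(theirs)
    return sorted(a - b), sorted(b - a)


if __name__ == "__main__":
    print("\n".join(freeze()))
```

Saving the output on both machines and passing the two lists to diff shows exactly which packages disagree.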

lberrada commented 2 years ago

Thinking more about this, I believe that the dependency discrepancy might be a red herring, since JAX seems to be set up well enough to at least allocate memory on the GPU.

The fact that JAX is using almost all of the GPU memory and freezes at the first step suggests instead that it could be struggling to compile the XLA program while fitting everything in memory. When the memory requirements are high, the compiler (which pmap invokes implicitly on its first call) searches for ways to fit everything in memory, which can sometimes be slow.
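One way to check whether the first step is stuck in compilation rather than execution is to time the two separately. This sketch uses a toy least-squares update (not the jax_privacy updater) and jax.jit's ahead-of-time lower(...).compile() API; it assumes a JAX release that provides that API, which older versions such as the 0.3.5 pin may not fully support. pmap also compiles lazily on its first call, so the same reasoning applies to _pmapped_update.

```python
# Time XLA compilation separately from execution for a toy update step.
# The loss and update are illustrative stand-ins, not jax_privacy code.
import time

import jax
import jax.numpy as jnp


def loss(params, batch):
    return jnp.mean((batch @ params) ** 2)


def update(params, batch):
    # One plain gradient-descent step on the toy loss.
    return params - 1e-3 * jax.grad(loss)(params, batch)


def time_compile_and_run(fn, *args):
    jitted = jax.jit(fn)
    start = time.perf_counter()
    compiled = jitted.lower(*args).compile()  # compilation only, no execution
    compile_s = time.perf_counter() - start

    start = time.perf_counter()
    out = compiled(*args)
    out.block_until_ready()  # wait for asynchronous dispatch to finish
    run_s = time.perf_counter() - start
    return out, compile_s, run_s


if __name__ == "__main__":
    params = jnp.ones((512,))
    batch = jnp.ones((64, 512))
    out, compile_s, run_s = time_compile_and_run(update, params, batch)
    print(f"compile: {compile_s:.3f}s, run: {run_s:.3f}s, shape {out.shape}")
```

If compile time dominates and grows sharply as the per-device batch or model size increases, that supports the compilation explanation.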

To test whether this is indeed the issue at hand, could you (1) share your experiment config, and (2) try running an experiment with a small model and a small batch size per device per step, so that we can see whether reducing the memory requirements helps?

RoyRin commented 2 years ago

That seems reasonable. I was using the provided config, https://github.com/deepmind/jax_privacy/blob/main/experiments/image_classification/configs/cifar10_wrn_16_4_eps1.py, which has:

batch_size=dict(
    init_value=4096,
    per_device_per_step=64,
    scale_schedule=None,  # example: {'2000': 8, '4000': 16},
),

I'll try making the batch size smaller and report back.
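Assuming init_value is the total batch size per parameter update and per_device_per_step is the slice each device processes per forward/backward pass, with the difference made up by gradient accumulation (an assumption about how the config is interpreted, not something confirmed in this thread), the numbers above work out as follows. accumulation_steps is a hypothetical helper.

```python
# Gradient-accumulation arithmetic for the batch_size config above.
# Assumes init_value = total batch per update, filled by accumulation.
def accumulation_steps(total_batch_size, per_device_per_step, num_devices):
    """Number of accumulation steps needed to reach the total batch size."""
    per_step = per_device_per_step * num_devices
    if total_batch_size % per_step:
        raise ValueError(
            f"batch size {total_batch_size} is not divisible by "
            f"{per_device_per_step} x {num_devices} device(s)")
    return total_batch_size // per_step


if __name__ == "__main__":
    for devices in (1, 8):
        steps = accumulation_steps(4096, 64, devices)
        print(f"{devices} device(s): {steps} accumulation steps per update")
```

Under that reading, lowering per_device_per_step reduces peak memory per step at the cost of more accumulation steps, while the total batch size of 4096 (and hence the privacy accounting that depends on it) stays the same.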

RoyRin commented 1 year ago

I thought this was resolved, but I ran into some more OOM warnings. Now that I'm tracing through the jax_privacy code, it looks like the entire dataset is loaded onto the GPU at each step, and utilization hits 100%.
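One thing worth ruling out when a monitor such as nvidia-smi shows the GPU memory as full: by default, JAX preallocates a large fraction of GPU memory when its backend starts (75% in current releases; older releases used 90%), regardless of how much the model and batch actually need. A minimal sketch that disables preallocation so the reported usage reflects real allocations, and prints per-device memory statistics where the backend exposes them (memory_stats may be missing or return None on some backends and versions, which the code tolerates):

```python
# Disable JAX's GPU memory preallocation, then report per-device memory.
import os

# Must be set before JAX initialises its GPU backend, so set it before
# importing jax (or at least before the first JAX operation).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand

import jax  # noqa: E402  (deliberately imported after setting the env var)


def device_memory_report():
    """Return {device name: memory_stats dict or None} for local devices."""
    report = {}
    for device in jax.local_devices():
        try:
            stats = device.memory_stats()  # may be None, e.g. on CPU
        except Exception:  # not implemented on every backend or version
            stats = None
        report[str(device)] = stats
    return report


if __name__ == "__main__":
    for name, stats in device_memory_report().items():
        in_use = stats.get("bytes_in_use") if stats else None
        print(f"{name}: bytes_in_use={in_use}")
```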

This has drifted from the original issue, so I can open a separate issue if need be. But the direct question is:

Is this warning expected behaviour?

6 dataset_info.py:491] Load dataset info from /h/royrin/tensorflow_datasets/cifar10/3.0.2
I1006 15:49:31.985438 140647251464896 dataset_info.py:550] Field info.splits from disk and from code do not match. Keeping the one from code.
I1006 15:49:31.985566 140647251464896 dataset_info.py:550] Field info.supervised_keys from disk and from code do not match. Keeping the one from code.
I1006 15:49:31.985766 140647251464896 dataset_builder.py:383] Reusing dataset cifar10 (/h/royrin/tensorflow_datasets/cifar10/3.0.2)
I1006 15:49:31.985876 140647251464896 logging_logger.py:44] Constructing tf.data.Dataset cifar10 for split test, from /h/royrin/tensorflow_datasets/cifar10/3.0.2

These warnings don't seem related, but it's possible they somehow cause the entire dataset to be loaded onto the GPU at each step, so I want to check.

lberrada commented 1 year ago

Yes, this is expected: these are standard logs that tensorflow_datasets outputs when loading data.

Our code should not load the entire dataset onto the device at each iteration, unless the batch size per device per step is set that large. At each iteration, we expect images of shape [num_devices, batch_size_per_step_per_device, augmult, *image_dimensions].
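To put that shape in perspective, here is the arithmetic for one device with the config above (per_device_per_step=64) and CIFAR-10 images (32x32x3, assumed channels-last and float32). The augmult value of 16 is a hypothetical choice for illustration, not taken from this thread; batch_shape and nbytes are hypothetical helpers.

```python
# Size of one per-step input batch versus the whole CIFAR-10 training set.
# augmult=16 and float32 inputs are illustrative assumptions.
import math


def batch_shape(num_devices, per_device_per_step, augmult, image_dims):
    """[num_devices, batch_size_per_step_per_device, augmult, *image_dims]."""
    return (num_devices, per_device_per_step, augmult, *image_dims)


def nbytes(shape, bytes_per_element=4):  # 4 bytes per float32 element
    return math.prod(shape) * bytes_per_element


if __name__ == "__main__":
    step = batch_shape(1, 64, 16, (32, 32, 3))
    full = (50_000, 32, 32, 3)  # CIFAR-10 training split
    print(step, f"{nbytes(step) / 2**20:.1f} MiB per step")
    print(full, f"{nbytes(full) / 2**20:.1f} MiB for the full training set")
```

The per-step input is about 12 MiB against roughly 586 MiB for the full training set, so the input batch itself is small. Note, though, that the network runs on per_device_per_step x augmult images (1,024 here), so activation memory, not the input, is what grows with augmult.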

RoyRin commented 1 year ago

Okay, thank you, I will close this issue. I am still getting memory issues sometimes, but I believe that has to do with the machine and not the code. If it comes up again, I will reopen it.

Thank you for your help!