ssnl / dataset-distillation

Open-source code for paper "Dataset Distillation"
https://ssnl.github.io/dataset_distillation
MIT License

RuntimeError: CUDA out of memory. #53

Closed · data-science-lover closed this issue 2 years ago

data-science-lover commented 2 years ago

Hi, I have the following error when using the GPU on my own dataset (2 classes) and my own model:

"RuntimeError: CUDA out of memory. Tried to allocate 294.00 MiB (GPU 0; 11.17 GiB total capacity; 10.47 GiB already allocated; 107.25 MiB free; 10.65 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"

Following what you explained in https://github.com/SsnL/dataset-distillation/issues/28, I tried different combinations of distill_steps, distill_epochs, distilled_images_per_class_per_step and num_distill_classes. I realized that the GPU limit was reached when num_distill_classes * distill_epochs * distilled_images_per_class_per_step * distill_steps > 4.
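In code form (a sketch; the names mirror the command-line flags, and the threshold 4 is just what I observe on the K80):

```python
# Sketch of the empirical limit above; names mirror the command-line flags.
def memory_factor(num_distill_classes, distill_epochs,
                  distilled_images_per_class_per_step, distill_steps):
    return (num_distill_classes * distill_epochs
            * distilled_images_per_class_per_step * distill_steps)

memory_factor(1, 2, 1, 2)  # 4: the largest product that still fits for me
# any combination whose product exceeds 4 triggers the CUDA OOM above
```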

The problem is that with 2 epochs, 2 steps, 1 image and 1 distilled class, the results are not sufficient.

What can I do to improve them?

PS: I use a Tesla K80 (12 GB of dedicated memory) with 56 GB of RAM.

Thank you in advance

data-science-lover commented 2 years ago

Also, you said to decrease the steps and the number of images per class, but the number of epochs also seems to influence the number of images obtained, right?

For example, with --distill_epochs 5 --distill_steps 10 --num_distill_classes 1 --distilled_images_per_class_per_step 1, we should get 50 images.

So I was wondering whether we should keep only the images from the last step or keep all of them.

ssnl commented 2 years ago

> The number of epochs also seems to influence the number of images obtained, right?

The number of epochs does not affect the number of images distilled, but yes, it increases memory usage.
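As a rough sketch of the accounting with your flags (my reading of the flag semantics from memory, so double-check against the code):

```python
# Rough accounting (a sketch, not pulled from the repo's code).
distill_steps = 10
distill_epochs = 5
num_distill_classes = 1
distilled_images_per_class_per_step = 1

# Distinct distilled images: one batch per step, reused every epoch.
distinct_images = (distill_steps * distilled_images_per_class_per_step
                   * num_distill_classes)
print(distinct_images)  # 10

# Unrolled training steps whose graph autograd keeps in memory.
unrolled_steps = distill_steps * distill_epochs
print(unrolled_steps)   # 50 -> this is what grows with more epochs
```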

ssnl commented 2 years ago

To reduce memory usage without sacrificing #images, #steps, or #epochs, you can try either implementing gradient checkpointing or using distributed training.

data-science-lover commented 2 years ago

> The number of epochs does not affect the number of images distilled, but yes, it increases memory usage.

Thank you for this very quick answer.

So it's not normal that, with 5 distillation epochs and 10 steps (and --num_distill_classes 1 --distilled_images_per_class_per_step 1), I get 50 images?

data-science-lover commented 2 years ago

> To reduce memory usage without sacrificing #images, #steps, or #epochs, you can try either implementing gradient checkpointing or using distributed training.

Thank you for the advice! I will try one of these methods and keep you posted on the progress.

After further analysis, I noticed that the GPU memory limit is hit in train_distilled_images.py at the line `output = model.forward_with_param(rdata, params[-1])`.

It goes into this function but hits the limit before returning.
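For anyone reproducing this, the peak allocation around a forward call can be watched like this (a toy stand-in model; the real spike is at the forward_with_param line above):

```python
# Sketch: watching the peak CUDA allocation around a forward pass.
# A toy linear model stands in for the repo's model here.
import torch

model = torch.nn.Linear(1024, 1024).cuda()
data = torch.randn(4096, 1024, device="cuda")

torch.cuda.reset_peak_memory_stats()
out = model(data)
print(f"peak after forward: {torch.cuda.max_memory_allocated() / 2**20:.1f} MiB")
```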

data-science-lover commented 2 years ago

> To reduce memory usage without sacrificing #images, #steps, or #epochs, you can try either implementing gradient checkpointing or using distributed training.

So I tried to implement gradient checkpointing. It works for network training, but not for the data distillation itself. The problem seems to come from an incompatibility between torch.utils.checkpoint and loss.backward() in basic.py (which uses torch.autograd.backward).

[screenshots of the checkpointing errors]

ssnl commented 2 years ago

The PyTorch one won't work for the reasons you described. You would need to implement it manually.
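Roughly, "manually" would look something like this toy sketch (none of it is the repo's code; all names are made up): save the inner-loop weights only at segment boundaries in a graph-free forward pass, then recompute each segment with create_graph=True while backpropagating the outer loss through it.

```python
# Toy sketch of manual gradient checkpointing for an unrolled inner
# training loop, the structure dataset distillation backprops through.
# Everything here (losses, names, segment_size) is illustrative.
import torch


def inner_loss(w, x):
    # Toy "training" loss of weights w on a distilled point x.
    return ((w - x) ** 2).sum()


def outer_loss(w, real_data):
    # Toy evaluation loss of the trained weights on real data.
    return ((w - real_data) ** 2).sum()


def unroll_segment(w, xs, lr):
    # Differentiable SGD steps; create_graph=True keeps the second-order
    # dependence of the update on both w and the distilled points xs.
    for x in xs:
        (g,) = torch.autograd.grad(inner_loss(w, x), w, create_graph=True)
        w = w - lr * g
    return w


def distill_grads_checkpointed(w0, distilled, real_data, lr, segment_size):
    segments = [distilled[i:i + segment_size]
                for i in range(0, len(distilled), segment_size)]

    # Forward: run the unroll WITHOUT keeping any graph, storing only the
    # weights at segment boundaries. This is where the memory is saved.
    boundaries = [w0.detach()]
    w = w0.detach().requires_grad_(True)
    for seg in segments:
        for x in seg:
            (g,) = torch.autograd.grad(inner_loss(w, x.detach()), w)
            w = (w - lr * g).detach().requires_grad_(True)
        boundaries.append(w.detach())

    # Backward: gradient of the outer loss w.r.t. the final weights ...
    final_w = boundaries[-1].detach().requires_grad_(True)
    (grad_w,) = torch.autograd.grad(outer_loss(final_w, real_data), final_w)

    # ... then walk the segments in reverse, recomputing each one WITH a
    # graph and chaining the vector-Jacobian product through it.
    grads_x = [None] * len(segments)
    for i in reversed(range(len(segments))):
        w_in = boundaries[i].detach().requires_grad_(True)
        w_out = unroll_segment(w_in, segments[i], lr)
        grads = torch.autograd.grad(w_out, [w_in] + list(segments[i]),
                                    grad_outputs=grad_w)
        grad_w, grads_x[i] = grads[0], grads[1:]
    return grads_x  # gradients of the outer loss w.r.t. each distilled point


if __name__ == "__main__":
    distilled = [torch.randn(8, requires_grad=True) for _ in range(10)]
    grads = distill_grads_checkpointed(torch.zeros(8), distilled,
                                       torch.randn(8), lr=0.1, segment_size=2)
    print([g.shape for seg_grads in grads for g in seg_grads])
```

Only one segment's graph is alive at any time, so peak memory scales with segment_size instead of with distill_steps * distill_epochs, at the cost of recomputing each segment once.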
