ShivamShrirao / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch
https://huggingface.co/docs/diffusers
Apache License 2.0

Training process fails with a Jax library related issue #255

Open randheerDas opened 3 months ago

randheerDas commented 3 months ago

Describe the bug

Training process fails with a Jax library related issue.

This is the Python code in the notebook cell that fails:

!python3 train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_name_or_path="stabilityai/sd-vae-ft-mse" \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --seed=1337 \
  --resolution=512 \
  --train_batch_size=1 \
  --train_text_encoder \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --num_class_images=50 \
  --sample_batch_size=4 \
  --max_train_steps=800 \
  --save_interval=10000 \
  --save_sample_prompt="photo of narrow gate" \
  --concepts_list="concepts_list.json"
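The command reads its concepts from concepts_list.json, which is not included in the report. For anyone trying to reproduce the failure, here is a minimal sketch that writes such a file, assuming the key names used by the ShivamShrirao DreamBooth example (instance_prompt, class_prompt, instance_data_dir, class_data_dir). Every prompt and folder path below is hypothetical.

```python
# Sketch: write the concepts_list.json file that --concepts_list points to.
# Key names are assumed from the ShivamShrirao DreamBooth example;
# all prompts and paths are hypothetical placeholders.
import json
import os

concepts_list = [
    {
        "instance_prompt": "photo of narrow gate",  # hypothetical, mirrors --save_sample_prompt
        "class_prompt": "photo of a gate",          # hypothetical class prompt
        "instance_data_dir": "data/narrow_gate",    # hypothetical folder of training images
        "class_data_dir": "data/gate",              # hypothetical folder for prior-preservation images
    },
]

# Create the folders so the paths exist; the instance folder still needs images copied in.
for concept in concepts_list:
    os.makedirs(concept["instance_data_dir"], exist_ok=True)
    os.makedirs(concept["class_data_dir"], exist_ok=True)

# Write the list as JSON next to train_dreambooth.py.
with open("concepts_list.json", "w") as f:
    json.dump(concepts_list, f, indent=4)
```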

A screenshot of the error was attached:

[Screenshot: Error]

Reproduction

Run the training process with the same train_dreambooth.py command shown above.

Logs

No response

System Info

I am running this in a Google Colab runtime (Python 3 on a Google Compute Engine backend) with a Tesla GPU.

Install details:

!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/examples/dreambooth/train_dreambooth.py
!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py
%pip install -qq git+https://github.com/ShivamShrirao/diffusers
%pip install -q -U --pre triton
%pip install -q accelerate transformers ftfy bitsandbytes==0.35.0 gradio natsort safetensors xformers
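Since the Logs section is empty, exact package versions would help narrow this down. A minimal sketch that prints them from a notebook cell using only the standard library's importlib.metadata; the package list is an assumption (the install commands above plus jax and jaxlib).

```python
# Sketch: report the installed versions of packages likely involved in this issue.
# The PACKAGES list is an assumption; extend it as needed.
import platform
from importlib.metadata import PackageNotFoundError, version

PACKAGES = ["jax", "jaxlib", "torch", "torchvision", "diffusers",
            "accelerate", "transformers", "bitsandbytes", "xformers", "triton"]


def collect_versions(packages):
    """Return {"python": ..., package: version or "not installed"}."""
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    # Paste this output into the System Info section of the issue.
    for name, ver in collect_versions(PACKAGES).items():
        print(f"{name}: {ver}")
```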

mahaboobkhan29 commented 3 months ago

Any update? I'm facing the same issue.

The-Ramosian commented 3 months ago

This seems to work:

!pip install "jax[cuda12_local]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

JossCamp commented 2 months ago

Google always ends up breaking something with each update, so you need to pin specific versions:

!pip install jax==0.4.19 jaxlib==0.4.19 -f https://storage.googleapis.com/jax-releases/jax_releases.html

This solves the problem for now.
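Because a later pip install in the same notebook can override an earlier one, it may be worth confirming that the pin actually took effect. A minimal sketch, assuming the jax 0.4.19 / jaxlib 0.4.19 pins from the comment above. It compares only the public part of each version, so a CUDA build with a local tag (for example 0.4.19+cuda12.cudnn89) still counts as matching. Note that modules already imported in the running session keep their old version until the runtime is restarted.

```python
# Sketch: check that installed jax/jaxlib match the pins suggested in this thread.
from importlib.metadata import PackageNotFoundError, version

# Pins taken from the workaround above.
PINS = {"jax": "0.4.19", "jaxlib": "0.4.19"}


def public_version(v):
    """Drop a PEP 440 local label such as '+cuda12.cudnn89'; None passes through."""
    return v.split("+", 1)[0] if v else v


def installed_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def pin_mismatches(pins, installed):
    """Return {name: (wanted, actual)} for every package that does not match its pin."""
    return {
        name: (wanted, installed.get(name))
        for name, wanted in pins.items()
        if public_version(installed.get(name)) != wanted
    }


if __name__ == "__main__":
    bad = pin_mismatches(PINS, installed_versions(PINS))
    if not bad:
        print("jax/jaxlib pins are in place")
    for name, (wanted, actual) in bad.items():
        print(f"{name}: wanted {wanted}, found {actual or 'not installed'}")
```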

mirodil-ml commented 2 months ago

Indeed, !pip install jax==0.4.19 jaxlib==0.4.19 -f https://storage.googleapis.com/jax-releases/jax_releases.html solves this issue, but then a different error appears (RuntimeError: operator torchvision::nms does not exist):

Traceback (most recent call last):
  File "/content/train_dreambooth.py", line 26, in <module>
    from torchvision import transforms
  File "/usr/local/lib/python3.10/dist-packages/torchvision/__init__.py", line 6, in <module>
    from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils
  File "/usr/local/lib/python3.10/dist-packages/torchvision/_meta_registrations.py", line 164, in <module>
    def meta_nms(dets, scores, iou_threshold):
  File "/usr/local/lib/python3.10/dist-packages/torch/library.py", line 467, in inner
    handle = entry.abstract_impl.register(func_to_register, source)
  File "/usr/local/lib/python3.10/dist-packages/torch/_library/abstract_impl.py", line 30, in register
    if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"):
RuntimeError: operator torchvision::nms does not exist

Probably the PyTorch version needs to be pinned too, but which version?
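One common cause of "operator torchvision::nms does not exist" (an assumption here; nobody in the thread confirms it) is a torchvision build that does not match the installed torch, for example after a later pip install upgrades torch on its own. A minimal sketch of that check; the pairing table is a partial transcription of the compatibility table in the torchvision README and should be verified there before relying on it.

```python
# Sketch: check whether the installed torch and torchvision look like a matching pair.
from importlib.metadata import PackageNotFoundError, version

# Partial torch -> torchvision minor-release pairing (verify against the official table).
TORCH_TO_TORCHVISION = {
    "1.13": "0.14",
    "2.0": "0.15",
    "2.1": "0.16",
    "2.2": "0.17",
    "2.3": "0.18",
    "2.4": "0.19",
}


def major_minor(v):
    """Reduce e.g. '2.2.1+cu121' to '2.2'."""
    public = v.split("+", 1)[0]
    return ".".join(public.split(".")[:2])


def torchvision_matches(torch_version, torchvision_version):
    """True/False if the pair does/doesn't match the table; None if torch isn't listed."""
    expected = TORCH_TO_TORCHVISION.get(major_minor(torch_version))
    if expected is None:
        return None
    return major_minor(torchvision_version) == expected


if __name__ == "__main__":
    try:
        torch_v, vision_v = version("torch"), version("torchvision")
    except PackageNotFoundError:
        print("torch and/or torchvision is not installed")
    else:
        result = torchvision_matches(torch_v, vision_v)
        if result is None:
            print(f"torch {torch_v} is not in this partial table; check the official one")
        elif result:
            print(f"torch {torch_v} and torchvision {vision_v} look like a matching pair")
        else:
            print(f"mismatch: torch {torch_v} vs torchvision {vision_v}")
```

If the check reports a mismatch, installing the torchvision release paired with the installed torch (or the reverse) is the usual direction to try; this thread does not settle which exact versions to pin.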

roman19932024 commented 2 months ago

(Quoting @mirodil-ml's comment above about the RuntimeError: operator torchvision::nms does not exist traceback.)

I have the same problem. Did anyone find a solution?

mirodil-ml commented 2 months ago

@roman19932024 Try updating the Python version to 3.10.