ShivamShrirao / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch
https://huggingface.co/docs/diffusers
Apache License 2.0
1.89k stars, 509 forks

AttributeError: module 'jax.random' has no attribute 'KeyArray' while running DreamBooth_Stable_Diffusion.ipynb #257

Open mirodil-ml opened 2 months ago

mirodil-ml commented 2 months ago

### Describe the bug

Running train_dreambooth.py on Google Colab throws the following error:

AttributeError: module 'jax.random' has no attribute 'KeyArray'

### Reproduction

### Logs

```
Traceback (most recent call last):
  File "/content/train_dreambooth.py", line 21, in <module>
    from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/__init__.py", line 36, in <module>
    from .models import (
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/__init__.py", line 33, in <module>
    from .controlnet_flax import FlaxControlNetModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/controlnet_flax.py", line 25, in <module>
    from .modeling_flax_utils import FlaxModelMixin
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 45, in <module>
    class FlaxModelMixin:
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 192, in FlaxModelMixin
    def init_weights(self, rng: jax.random.KeyArray) -> Dict:
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.random' has no attribute 'KeyArray'
```
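The traceback shows the failure happens while Python is still defining `FlaxModelMixin`: the annotation `rng: jax.random.KeyArray` on `init_weights` is evaluated as soon as the class body runs. That is why merely importing diffusers fails, even though train_dreambooth.py only imports PyTorch classes, once the installed JAX no longer provides `KeyArray`. The sketch below reproduces that pattern with stand-in modules. The names `define_mixin`, `try_import`, and the fake modules are hypothetical and not part of diffusers or JAX, and it assumes Python 3.13 or earlier, because 3.14 defers annotation evaluation by default.

```python
import sys
import types


def define_mixin(random_module):
    """Hypothetical reproduction of the failing pattern in modeling_flax_utils.py.

    On Python <= 3.13 (without `from __future__ import annotations`), the
    annotation `random_module.KeyArray` is evaluated as soon as the class
    body runs, which for a module-level class means at import time.
    """

    class FlaxModelMixin:
        def init_weights(self, rng: random_module.KeyArray) -> dict:
            return {}

    return FlaxModelMixin


def try_import(random_module):
    """Return None if the class is created, else the AttributeError message."""
    try:
        define_mixin(random_module)
    except AttributeError as exc:
        return str(exc)
    return None


# Stand-in for an older `jax.random` that still exposes `KeyArray`.
old_random = types.ModuleType("random_with_keyarray")
old_random.KeyArray = object

# Stand-in for a newer `jax.random` where `KeyArray` no longer exists.
new_random = types.ModuleType("random_without_keyarray")

print(try_import(old_random))  # prints "None": the attribute exists
print(try_import(new_random))  # on Python <= 3.13: the AttributeError message
```

The same class definition succeeds or fails depending only on which version of the module is installed, which matches the report: the diffusers code did not change, but the JAX on "latest Google Colab" did.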


### System Info

Latest Google Colab
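"Latest Google Colab" does not pin down which package versions were involved. A minimal sketch for collecting them, assuming the relevant packages are the ones named below (the list and the helper `installed_versions` are illustrative, not from the original report), uses only the standard library, so it runs even when `import diffusers` itself fails:

```python
from importlib.metadata import PackageNotFoundError, version

# Packages relevant to this traceback; this list is an assumption, adjust as needed.
PACKAGES = ("diffusers", "jax", "jaxlib", "flax", "torch", "transformers")


def installed_versions(packages=PACKAGES):
    """Map each package name to its installed version, or None if absent."""
    report = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


# Print one line per package, suitable for pasting under "System Info".
for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```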
cimplesid commented 2 months ago

@mirodil-ml Use this:

pip install --upgrade diffusers
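Upgrading diffusers is the fix suggested in the thread. If upgrading is not an option (for example, when a notebook pins a particular diffusers checkout), one possible stopgap is to restore the missing name before diffusers is imported. This is a sketch under assumptions, not a documented fix: `add_missing_alias` is a hypothetical helper, and using `jax.Array` as the stand-in for the removed `jax.random.KeyArray` is an assumption that only satisfies the annotation at class-definition time.

```python
def add_missing_alias(module, name, value):
    """Hypothetical helper: set module.<name> = value only if it is missing.

    Returns True if the alias was added, False if the attribute already existed.
    """
    if hasattr(module, name):
        return False
    setattr(module, name, value)
    return True


try:
    import jax
except ImportError:
    jax = None

if jax is not None:
    # Assumption: jax.Array is an acceptable stand-in for the removed
    # jax.random.KeyArray annotation. This must run BEFORE `import diffusers`,
    # because the annotation is evaluated when diffusers defines FlaxModelMixin.
    add_missing_alias(jax.random, "KeyArray", jax.Array)
```

The helper leaves an existing `KeyArray` untouched, so running it against an older JAX that still provides the attribute is a no-op.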