Run it, and you will see the error at the !python3 train_dreambooth.py .... step.
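Before rerunning the training step, it may help to confirm what the Colab runtime actually has installed. The sketch below is a hypothetical pre-flight check. The helper names installed_version and has_attr_path are illustrative and are not part of diffusers or JAX. It uses only importlib and importlib.metadata from the standard library. In an environment that reproduces this traceback, has_attr_path("jax.random", "KeyArray") should return False.

```python
from importlib import import_module, metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def has_attr_path(module_name: str, attr: str) -> bool:
    """Return True if module_name imports cleanly and exposes attr.

    A missing module or a missing attribute yields False instead of raising.
    """
    try:
        module = import_module(module_name)
    except ImportError:
        return False
    # hasattr() returns False when attribute lookup raises AttributeError.
    return hasattr(module, attr)


if __name__ == "__main__":
    # Report the versions of the packages that appear in the traceback.
    for dist in ("jax", "jaxlib", "diffusers"):
        print(f"{dist}: {installed_version(dist)}")
    # diffusers' modeling_flax_utils.py annotates init_weights with this name.
    print("jax.random.KeyArray available:", has_attr_path("jax.random", "KeyArray"))
```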
Logs
Traceback (most recent call last):
  File "/content/train_dreambooth.py", line 21, in <module>
    from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/__init__.py", line 36, in <module>
    from .models import (
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/__init__.py", line 33, in <module>
    from .controlnet_flax import FlaxControlNetModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/controlnet_flax.py", line 25, in <module>
    from .modeling_flax_utils import FlaxModelMixin
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 45, in <module>
    class FlaxModelMixin:
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 192, in FlaxModelMixin
    def init_weights(self, rng: jax.random.KeyArray) -> Dict:
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.random' has no attribute 'KeyArray'
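The traceback ends inside the class body of FlaxModelMixin (line 192) rather than in a function call. The annotation jax.random.KeyArray on init_weights is evaluated while the class is being defined, which happens during the import from diffusers. The minimal sketch below reproduces that mechanism without JAX. It assumes a stand-in module object named "jax.random" that lacks KeyArray, mirroring the installed JAX in the traceback. A hypothetical FlaxModelMixinSketch class takes the place of diffusers' FlaxModelMixin.

```python
import types
from typing import Optional


def reproduce_import_failure() -> Optional[str]:
    """Return the AttributeError message raised while defining the class, or None."""
    # Hypothetical stand-in for the installed jax.random: a bare module with no KeyArray.
    fake_random = types.ModuleType("jax.random")

    try:
        class FlaxModelMixinSketch:  # hypothetical stand-in for diffusers' FlaxModelMixin
            # On Python 3.10 (the runtime in the traceback) this annotation is
            # evaluated as soon as the `def` statement runs inside the class body.
            def init_weights(self, rng: fake_random.KeyArray) -> dict:
                return {}

        # Python versions that defer annotation evaluation raise here instead.
        _ = FlaxModelMixinSketch.init_weights.__annotations__
    except AttributeError as exc:
        return str(exc)
    return None


if __name__ == "__main__":
    print(reproduce_import_failure())
```

Running it should print an AttributeError message naming KeyArray, analogous to the last line of the traceback. In the environment shown, the failure occurs while diffusers/__init__.py imports its models package, so it is raised before any of train_dreambooth.py's own training code runs.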
Describe the bug
Running train_dreambooth.py on Google Colab throws the AttributeError shown in the log above.
Reproduction
!python3 train_dreambooth.py .... step
Logs