hollowstrawberry / kohya-colab

Accessible Google Colab notebooks for Stable Diffusion Lora training, based on the work of kohya-ss and Linaqruf
GNU General Public License v3.0

AttributeError: module 'jax.random' has no attribute 'KeyArray' #128

Closed abujr101 closed 2 months ago

abujr101 commented 2 months ago

Traceback (most recent call last):

/content/kohya-trainer/train_network.py:15 in <module>

     12 from tqdm import tqdm
     13 import torch
     14 from accelerate.utils import set_seed
  ❱  15 from diffusers import DDPMScheduler
     16
     17 import library.train_util as train_util
     18 from library.train_util import (

/usr/local/lib/python3.10/dist-packages/diffusers/__init__.py:38 in <module>

     35         get_polynomial_decay_schedule_with_warmup,
     36         get_scheduler,
     37     )
  ❱  38     from .pipeline_utils import DiffusionPipeline
     39     from .pipelines import (
     40         DanceDiffusionPipeline,
     41         DDIMPipeline,

/usr/local/lib/python3.10/dist-packages/diffusers/pipeline_utils.py:38 in <module>

     35 from .dynamic_modules_utils import get_class_from_dynamic_module
     36 from .hub_utils import http_user_agent, send_telemetry
     37 from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
  ❱  38 from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
     39 from .utils import (
     40     CONFIG_NAME,
     41     DIFFUSERS_CACHE,

/usr/local/lib/python3.10/dist-packages/diffusers/schedulers/__init__.py:50 in <module>

     47     from ..utils.dummy_flax_objects import *  # noqa F403
     48 else:
     49     from .scheduling_ddim_flax import FlaxDDIMScheduler
  ❱  50     from .scheduling_ddpm_flax import FlaxDDPMScheduler
     51     from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
     52     from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
     53     from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler

/usr/local/lib/python3.10/dist-packages/diffusers/schedulers/scheduling_ddpm_flax.py:80 in <module>

     77     state: DDPMSchedulerState
     78
     79
  ❱  80 class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
     81     """
     82     Denoising diffusion probabilistic models (DDPMs) explores the connections between de
     83     Langevin dynamics sampling.

/usr/local/lib/python3.10/dist-packages/diffusers/schedulers/scheduling_ddpm_flax.py:216 in FlaxDDPMScheduler

    213         model_output: jnp.ndarray,
    214         timestep: int,
    215         sample: jnp.ndarray,
  ❱ 216         key: random.KeyArray,
    217         return_dict: bool = True,
    218         **kwargs,
    219     ) -> Union[FlaxDDPMSchedulerOutput, Tuple]:

/usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py:54 in getattr

     51         raise AttributeError(message)
     52     warnings.warn(message, DeprecationWarning, stacklevel=2)
     53     return fn
  ❱  54     raise AttributeError(f"module {module!r} has no attribute {name!r}")
     55
     56 return getattr
     57

AttributeError: module 'jax.random' has no attribute 'KeyArray'

please help 😌

Karbadel commented 2 months ago

Same here, please help =(

abujr101 commented 2 months ago

This is literally happening with all the Kohya LoRA training Colabs out there.

gcoox commented 2 months ago

Hey friend, I finally solved this problem! Try this code, it may fix it:

!pip install timm==0.6.12 fairscale==0.4.13 transformers==4.26.0 requests==2.28.2 accelerate==0.15.0 diffusers[torch]==0.10.2 einops==0.6.0 safetensors==0.2.6 jax==0.4.23 jaxlib==0.4.23

Run this code before starting section 1.1. Once it finishes, you can use the notebook normally. Hope this helps.
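For anyone who wants to check their runtime before or after running the pin, here is a minimal sketch. It assumes, based only on the pinned version in the command above, that jax releases newer than 0.4.23 are the ones missing `jax.random.KeyArray`. The helper names (`parse_version`, `jax_is_newer_than_pin`, `has_key_array`) are made up for illustration and are not part of the notebook.

```python
import re
from importlib import metadata

# Version taken from the pip command above; treating it as the last
# working release is an assumption, not something confirmed in this thread.
PINNED_JAX = "0.4.23"


def parse_version(text):
    """Turn '0.4.23' (or '0.4.23rc1') into the tuple (0, 4, 23)."""
    parts = []
    for piece in text.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def jax_is_newer_than_pin(installed, pinned=PINNED_JAX):
    """True if the installed jax version is newer than the pinned one."""
    return parse_version(installed) > parse_version(pinned)


def has_key_array():
    """Check the attribute directly; returns None if jax cannot be imported."""
    try:
        import jax.random as jrandom
    except ImportError:
        return None
    return hasattr(jrandom, "KeyArray")


try:
    installed = metadata.version("jax")
except metadata.PackageNotFoundError:
    installed = None

if installed is None:
    print("jax is not installed in this runtime")
elif jax_is_newer_than_pin(installed):
    print(f"jax {installed} is newer than {PINNED_JAX}; the pin above may be needed")
else:
    print(f"jax {installed} is at or below {PINNED_JAX}")

print("jax.random.KeyArray available:", has_key_array())
```

The direct `hasattr` check is the more reliable of the two, since it tests for the exact attribute named in the error rather than relying on a version number.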

abujr101 commented 2 months ago

wow. that works. thank you.

hollowstrawberry commented 2 months ago

This was already fixed, please use the latest trainer.

https://colab.research.google.com/github/hollowstrawberry/kohya-colab/blob/main/Lora_Trainer.ipynb

Karbadel commented 2 months ago

Thanks a lot! =)

InPantsPro commented 2 months ago

Was it fixed on your XL trainer as well? I still receive the error from the Colab I have saved in my drive.

hollowstrawberry commented 2 months ago

Delete the saved copy and make a new one from the original. Or just use the original and don't make a copy, whichever you prefer. All 3 colabs were fixed at the same time yesterday.