hollowstrawberry / kohya-colab

Accessible Google Colab notebooks for Stable Diffusion Lora training, based on the work of kohya-ss and Linaqruf
GNU General Public License v3.0
617 stars, 93 forks

Getting this error #120

katywilliams1121 closed this issue 7 months ago

katywilliams1121 commented 7 months ago


env: PYTHONPATH=/content/kohya-trainer
2024-04-08 17:44:27.906022: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-08 17:44:27.906075: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-08 17:44:27.907510: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /content/kohya-trainer/finetune/make_captions.py:17 in <module>                                  │
│                                                                                                  │
│    14 from torchvision.transforms.functional import InterpolationMode                            │
│    15 sys.path.append(os.path.dirname(__file__))                                                 │
│    16 from blip.blip import blip_decoder                                                         │
│ ❱  17 import library.train_util as train_util                                                    │
│    18                                                                                            │
│    19 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")                      │
│    20                                                                                            │
│                                                                                                  │
│ /content/kohya-trainer/library/train_util.py:39 in <module>                                      │
│                                                                                                  │
│     36 from torchvision import transforms                                                        │
│     37 from transformers import CLIPTokenizer                                                    │
│     38 import transformers                                                                       │
│ ❱   39 import diffusers                                                                          │
│     40 from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION              │
│     41 from diffusers import (                                                                   │
│     42 │   StableDiffusionPipeline,                                                              │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/__init__.py:38 in <module>                     │
│                                                                                                  │
│    35 │   │   get_polynomial_decay_schedule_with_warmup,                                         │
│    36 │   │   get_scheduler,                                                                     │
│    37 │   )                                                                                      │
│ ❱  38 │   from .pipeline_utils import DiffusionPipeline                                          │
│    39 │   from .pipelines import (                                                               │
│    40 │   │   DanceDiffusionPipeline,                                                            │
│    41 │   │   DDIMPipeline,                                                                      │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/pipeline_utils.py:38 in <module>               │
│                                                                                                  │
│    35 from .dynamic_modules_utils import get_class_from_dynamic_module                           │
│    36 from .hub_utils import http_user_agent, send_telemetry                                     │
│    37 from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT                                     │
│ ❱  38 from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME                             │
│    39 from .utils import (                                                                       │
│    40 │   CONFIG_NAME,                                                                           │
│    41 │   DIFFUSERS_CACHE,                                                                       │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/schedulers/__init__.py:50 in <module>          │
│                                                                                                  │
│   47 │   from ..utils.dummy_flax_objects import *  # noqa F403                                   │
│   48 else:                                                                                       │
│   49 │   from .scheduling_ddim_flax import FlaxDDIMScheduler                                     │
│ ❱ 50 │   from .scheduling_ddpm_flax import FlaxDDPMScheduler                                     │
│   51 │   from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler        │
│   52 │   from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler                            │
│   53 │   from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler                      │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/schedulers/scheduling_ddpm_flax.py:80 in       │
│ <module>                                                                                         │
│                                                                                                  │
│    77 │   state: DDPMSchedulerState                                                              │
│    78                                                                                            │
│    79                                                                                            │
│ ❱  80 class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):                                  │
│    81 │   """                                                                                    │
│    82 │   Denoising diffusion probabilistic models (DDPMs) explores the connections between de   │
│    83 │   Langevin dynamics sampling.                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/schedulers/scheduling_ddpm_flax.py:216 in      │
│ FlaxDDPMScheduler                                                                                │
│                                                                                                  │
│   213 │   │   model_output: jnp.ndarray,                                                         │
│   214 │   │   timestep: int,                                                                     │
│   215 │   │   sample: jnp.ndarray,                                                               │
│ ❱ 216 │   │   key: random.KeyArray,                                                              │
│   217 │   │   return_dict: bool = True,                                                          │
│   218 │   │   **kwargs,                                                                          │
│   219 │   ) -> Union[FlaxDDPMSchedulerOutput, Tuple]:                                            │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py:54 in getattr                   │
│                                                                                                  │
│   51 │   │   raise AttributeError(message)                                                       │
│   52 │     warnings.warn(message, DeprecationWarning, stacklevel=2)                              │
│   53 │     return fn                                                                             │
│ ❱ 54 │   raise AttributeError(f"module {module!r} has no attribute {name!r}")                    │
│   55                                                                                             │
│   56   return getattr                                                                            │
│   57                                                                                             │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AttributeError: module 'jax.random' has no attribute 'KeyArray'
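The failure happens while diffusers is being imported, before any training code runs. In the traceback, FlaxDDPMScheduler.step annotates its key parameter as random.KeyArray. On Python 3.10 a function's annotations are evaluated when its def statement executes, so the missing attribute is hit at import time. The sketch below reproduces only that mechanism. fake_random and SchedulerLike are hypothetical stand-ins, not jax or diffusers code.

```python
import types

# Hypothetical stand-in for jax.random in a release that no longer defines KeyArray.
random = types.ModuleType("fake_random")


def define_scheduler_class():
    # Mirrors the shape of FlaxDDPMScheduler.step from the traceback: the
    # annotation `random.KeyArray` is looked up when `def step` runs
    # (eager annotation evaluation, as on Python 3.10).
    class SchedulerLike:
        def step(self, key: random.KeyArray, return_dict: bool = True):
            return key

    return SchedulerLike


def try_import_like_diffusers():
    """Return None if the class could be defined, else the AttributeError text."""
    try:
        define_scheduler_class()
    except AttributeError as err:
        return str(err)
    return None


# On Python 3.10 (as in the traceback) this prints:
#   module 'fake_random' has no attribute 'KeyArray'
# Python 3.14+ defers annotation evaluation, so there it prints None.
print(try_import_like_diffusers())
```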
EbisusBaka commented 7 months ago

Same here. I tried the other notebooks and got the same error. It seems something changed in the default Google Colab environment.

Or some dependency bumped jax: jax 0.4.23 was the latest version that had KeyArray. When I run !pip show jax, I get 0.4.26.
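One way to confirm this diagnosis from a Colab cell is to check the installed jax version and whether jax.random still exposes KeyArray. The following is a minimal sketch. The 0.4.23 cutoff is taken from the comment above rather than verified here, and the helper names (parse_version, expects_keyarray, installed_jax_status) are hypothetical.

```python
# Assumption taken from the comment above: 0.4.23 is the last jax release
# that still exposes jax.random.KeyArray.
LAST_JAX_WITH_KEYARRAY = (0, 4, 23)


def parse_version(v: str) -> tuple:
    """Leading numeric components only: "0.4.26" -> (0, 4, 26), "0.4.23.dev1" -> (0, 4, 23)."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def expects_keyarray(v: str) -> bool:
    # Compare numerically; as plain strings "0.4.9" <= "0.4.23" would be False.
    return parse_version(v) <= LAST_JAX_WITH_KEYARRAY


def installed_jax_status():
    """(version, has KeyArray) for the installed jax, or None if it cannot be imported."""
    try:
        import jax
    except ImportError:
        return None
    return jax.__version__, hasattr(jax.random, "KeyArray")


print(installed_jax_status())
```

With jax 0.4.26 installed, the status tuple should report the version and False for KeyArray, matching the traceback.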

DASDAWDDWADSADSA commented 7 months ago

Same here.

DeFek1 commented 7 months ago

Yup, even on the other notebook

gaco123 commented 7 months ago

Try adding this line to the Google Colab notebook to fix the issue:

!pip install jax==0.4.23 jaxlib==0.4.23 -f https://storage.googleapis.com/jax-releases/jax_releases.html

It works for me, but it's a temporary workaround; other changes may be needed for a permanent fix xd
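The same pin can be scripted from Python, which makes it easy to inspect the exact command before running it. This is a sketch under stated assumptions: pip_install_cmd and pin_packages are hypothetical helpers, and nothing is installed unless dry_run=False. If jax was already imported in the running Colab session, a runtime restart is likely needed before the pinned version is picked up.

```python
import subprocess
import sys

# URL from the comment above; passed to pip's -f (--find-links) option.
JAX_RELEASES = "https://storage.googleapis.com/jax-releases/jax_releases.html"


def pip_install_cmd(pins, find_links=None):
    """Build a `python -m pip install` argv with exact `name==version` pins.

    -f is only appended together with its URL, so the option never ends up
    without an argument.
    """
    cmd = [sys.executable, "-m", "pip", "install"]
    cmd += [f"{name}=={ver}" for name, ver in pins.items()]
    if find_links:
        cmd += ["-f", find_links]
    return cmd


def pin_packages(pins, find_links=None, dry_run=True):
    """Print the command; only execute it when dry_run is False."""
    cmd = pip_install_cmd(pins, find_links)
    print(" ".join(cmd))
    if not dry_run:
        subprocess.check_call(cmd)
    return cmd


# Dry run: shows the command equivalent to the !pip line above.
pin_packages({"jax": "0.4.23", "jaxlib": "0.4.23"}, find_links=JAX_RELEASES)
```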

propergurl commented 7 months ago

Please help with the fix: which xformers and jax versions should I use? Still no luck after adding this to the notebook:

!pip install jax==0.4.23 jaxlib==0.4.23 -f
!pip install -q xformers==0.0.26.dev777

Update: I found the versions that work:

!pip install jax==0.4.23 jaxlib==0.4.23 
!pip install xformers==0.0.24 --no-deps

Hope this helps.
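After installing, a check like the following can confirm that the environment actually matches the versions reported to work in the update above (jax 0.4.23, jaxlib 0.4.23, xformers 0.0.24). find_mismatches is a hypothetical helper built on the standard library's importlib.metadata; its get_version parameter exists only so the check can be exercised without those packages installed.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions reported as working in the update above.
WORKING_PINS = {"jax": "0.4.23", "jaxlib": "0.4.23", "xformers": "0.0.24"}


def find_mismatches(pins, get_version=version):
    """Map package -> (wanted, installed) for every pin that is not satisfied.

    installed is None when the package is missing. The comparison is an exact
    string match, so a local build tag (e.g. "+cu121") also counts as a mismatch.
    """
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = get_version(name)
        except PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches


print(find_mismatches(WORKING_PINS) or "all pins satisfied")
```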

hollowstrawberry commented 7 months ago

Should be fixed. I also changed the way errors are shown so that they're not ridiculously long.

byungtaekyu commented 7 months ago

Thank you!!