p1atdev / LECO

Low-rank adaptation for Erasing COncepts from diffusion models.
https://arxiv.org/abs/2303.07345
Apache License 2.0

SDXL training bug: AttributeError: module 'jax.random' has no attribute 'KeyArray' #40

Open Chadius opened 4 months ago

Chadius commented 4 months ago

As I mentioned in https://github.com/p1atdev/LECO/issues/39, jax deprecated jax.random.KeyArray and removed it entirely in 0.4.24. The installed diffusers still references it in its Flax model code, so importing diffusers fails and LECO has a broken dependency:

```bash
/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:311: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
2024-04-14 00:57:27.323177: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-14 00:57:27.323222: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-14 00:57:27.324499: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-04-14 00:57:28.392824: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:311: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
Traceback (most recent call last):
  File "/content/LECO/train_lora_xl.py", line 15, in <module>
    from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
  File "/content/LECO/lora.py", line 11, in <module>
    from diffusers import UNet2DConditionModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/__init__.py", line 38, in <module>
    from .models import (
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/__init__.py", line 36, in <module>
    from .controlnet_flax import FlaxControlNetModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/controlnet_flax.py", line 25, in <module>
    from .modeling_flax_utils import FlaxModelMixin
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 46, in <module>
    class FlaxModelMixin(PushToHubMixin):
  File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 195, in FlaxModelMixin
    def init_weights(self, rng: jax.random.KeyArray) -> Dict:
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.random' has no attribute 'KeyArray'
```
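
The traceback above fails on a type annotation, so the name only has to exist when diffusers is imported. One possible stopgap, sketched below, is to restore the missing name before that import. This is an untested workaround under assumptions, not a fix from LECO, diffusers, or JAX: `restore_missing_attr` and `patch_jax_keyarray` are hypothetical helper names, and using `jax.Array` as the stand-in is an assumption based on the JAX deprecation notice recommending it for annotations.

```python
import types


def restore_missing_attr(module, name, fallback):
    """Set module.<name> = fallback only when the attribute is missing.

    Returns True if the alias was added, False if the name already existed.
    """
    if hasattr(module, name):
        return False
    setattr(module, name, fallback)
    return True


def patch_jax_keyarray():
    """Hypothetical shim: alias jax.random.KeyArray to jax.Array if it is gone.

    Returns False when jax cannot be imported at all (for example, when the
    import itself raises, as in a broken install).
    """
    try:
        import jax
        import jax.random
    except (ImportError, AttributeError):
        return False
    # Assumption: jax.Array is an acceptable stand-in for the annotation.
    return restore_missing_attr(jax.random, "KeyArray", jax.Array)


# Demonstration on a stand-in module, so the sketch runs without JAX installed.
fake_random = types.ModuleType("fake_random")
print(restore_missing_attr(fake_random, "KeyArray", object))  # alias added
print(restore_missing_attr(fake_random, "KeyArray", int))     # already present
```

To try it, `patch_jax_keyarray()` would have to be called at the top of train_lora_xl.py, before anything imports diffusers. It only papers over the missing name; pinning mutually compatible versions of jax and diffusers would be the cleaner route.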

I tried downgrading jax to 0.4.23, but now the import fails with a circular-import error instead:

```bash
/usr/local/lib/python3.10/dist-packages/torch/distributed/_functional_collectives.py:28: UserWarning: Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly
  warnings.warn(
Traceback (most recent call last):
  File "/content/LECO/train_lora_xl.py", line 15, in <module>
    from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
  File "/content/LECO/lora.py", line 11, in <module>
    from diffusers import UNet2DConditionModel
  File "/usr/local/lib/python3.10/dist-packages/diffusers/__init__.py", line 3, in <module>
    from .configuration_utils import ConfigMixin
  File "/usr/local/lib/python3.10/dist-packages/diffusers/configuration_utils.py", line 34, in <module>
    from .utils import (
  File "/usr/local/lib/python3.10/dist-packages/diffusers/utils/__init__.py", line 21, in <module>
    from .accelerate_utils import apply_forward_hook
  File "/usr/local/lib/python3.10/dist-packages/diffusers/utils/accelerate_utils.py", line 24, in <module>
    import accelerate
  File "/usr/local/lib/python3.10/dist-packages/accelerate/__init__.py", line 3, in <module>
    from .accelerator import Accelerator
  File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 35, in <module>
    from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
  File "/usr/local/lib/python3.10/dist-packages/accelerate/checkpointing.py", line 24, in <module>
    from .utils import (
  File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/__init__.py", line 135, in <module>
    from .fsdp_utils import load_fsdp_model, load_fsdp_optimizer, save_fsdp_model, save_fsdp_optimizer
  File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/fsdp_utils.py", line 25, in <module>
    import torch.distributed.checkpoint as dist_cp
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/__init__.py", line 7, in <module>
    from .state_dict_loader import load_state_dict, load
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 12, in <module>
    from .default_planner import DefaultLoadPlanner
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/default_planner.py", line 14, in <module>
    from torch.distributed._tensor import DTensor
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_tensor/__init__.py", line 346, in <module>
    import torch.distributed._tensor._dynamo_utils
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_tensor/_dynamo_utils.py", line 1, in <module>
    from torch._dynamo import allow_in_graph
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/__init__.py", line 2, in <module>
    from . import allowed_functions, convert_frame, eval_frame, resume_execution
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 62, in <module>
    from .output_graph import OutputGraph
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 39, in <module>
    from . import config, logging as torchdynamo_logging, variables
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/__init__.py", line 26, in <module>
    from .higher_order_ops import TorchHigherOrderOperatorVariable
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/higher_order_ops.py", line 11, in <module>
    import torch.onnx.operators
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/__init__.py", line 46, in <module>
    from ._internal.exporter import (  # usort:skip. needs to be last to avoid circular import
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/exporter.py", line 44, in <module>
    from torch.onnx._internal.fx import (
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/__init__.py", line 1, in <module>
    from .patcher import ONNXTorchPatcher
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/patcher.py", line 11, in <module>
    import transformers  # type: ignore[import]
  File "/usr/local/lib/python3.10/dist-packages/transformers/__init__.py", line 26, in <module>
    from . import dependency_versions_check
  File "/usr/local/lib/python3.10/dist-packages/transformers/dependency_versions_check.py", line 16, in <module>
    from .utils.versions import require_version, require_version_core
  File "/usr/local/lib/python3.10/dist-packages/transformers/utils/__init__.py", line 31, in <module>
    from .generic import (
  File "/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py", line 33, in <module>
    import jax.numpy as jnp
  File "/usr/local/lib/python3.10/dist-packages/jax/__init__.py", line 39, in <module>
    from jax import config as _config_module
  File "/usr/local/lib/python3.10/dist-packages/jax/config.py", line 15, in <module>
    from jax._src.config import config as _deprecated_config  # noqa: F401
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/config.py", line 28, in <module>
    from jax._src import lib
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/lib/__init__.py", line 75, in <module>
    jax_version=jax.version.__version__,
AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)
```