/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:311: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
torch.utils._pytree._register_pytree_node(
2024-04-14 00:57:27.323177: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-14 00:57:27.323222: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-14 00:57:27.324499: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-04-14 00:57:28.392824: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py:311: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
torch.utils._pytree._register_pytree_node(
Traceback (most recent call last):
File "/content/LECO/train_lora_xl.py", line 15, in <module>
from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
File "/content/LECO/lora.py", line 11, in <module>
from diffusers import UNet2DConditionModel
File "/usr/local/lib/python3.10/dist-packages/diffusers/__init__.py", line 38, in <module>
from .models import (
File "/usr/local/lib/python3.10/dist-packages/diffusers/models/__init__.py", line 36, in <module>
from .controlnet_flax import FlaxControlNetModel
File "/usr/local/lib/python3.10/dist-packages/diffusers/models/controlnet_flax.py", line 25, in <module>
from .modeling_flax_utils import FlaxModelMixin
File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 46, in <module>
class FlaxModelMixin(PushToHubMixin):
File "/usr/local/lib/python3.10/dist-packages/diffusers/models/modeling_flax_utils.py", line 195, in FlaxModelMixin
def init_weights(self, rng: jax.random.KeyArray) -> Dict:
File "/usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py", line 54, in getattr
raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.random' has no attribute 'KeyArray'
```
I tried downgrading jax to 0.4.23, but now the script fails with a circular import error instead:
```bash
/usr/local/lib/python3.10/dist-packages/torch/distributed/_functional_collectives.py:28: UserWarning: Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly
warnings.warn(
Traceback (most recent call last):
File "/content/LECO/train_lora_xl.py", line 15, in <module>
from lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
File "/content/LECO/lora.py", line 11, in <module>
from diffusers import UNet2DConditionModel
File "/usr/local/lib/python3.10/dist-packages/diffusers/__init__.py", line 3, in <module>
from .configuration_utils import ConfigMixin
File "/usr/local/lib/python3.10/dist-packages/diffusers/configuration_utils.py", line 34, in <module>
from .utils import (
File "/usr/local/lib/python3.10/dist-packages/diffusers/utils/__init__.py", line 21, in <module>
from .accelerate_utils import apply_forward_hook
File "/usr/local/lib/python3.10/dist-packages/diffusers/utils/accelerate_utils.py", line 24, in <module>
import accelerate
File "/usr/local/lib/python3.10/dist-packages/accelerate/__init__.py", line 3, in <module>
from .accelerator import Accelerator
File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 35, in <module>
from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
File "/usr/local/lib/python3.10/dist-packages/accelerate/checkpointing.py", line 24, in <module>
from .utils import (
File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/__init__.py", line 135, in <module>
from .fsdp_utils import load_fsdp_model, load_fsdp_optimizer, save_fsdp_model, save_fsdp_optimizer
File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/fsdp_utils.py", line 25, in <module>
import torch.distributed.checkpoint as dist_cp
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/__init__.py", line 7, in <module>
from .state_dict_loader import load_state_dict, load
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 12, in <module>
from .default_planner import DefaultLoadPlanner
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/default_planner.py", line 14, in <module>
from torch.distributed._tensor import DTensor
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_tensor/__init__.py", line 346, in <module>
import torch.distributed._tensor._dynamo_utils
File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_tensor/_dynamo_utils.py", line 1, in <module>
from torch._dynamo import allow_in_graph
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/__init__.py", line 2, in <module>
from . import allowed_functions, convert_frame, eval_frame, resume_execution
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 62, in <module>
from .output_graph import OutputGraph
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 39, in <module>
from . import config, logging as torchdynamo_logging, variables
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/__init__.py", line 26, in <module>
from .higher_order_ops import TorchHigherOrderOperatorVariable
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/higher_order_ops.py", line 11, in <module>
import torch.onnx.operators
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/__init__.py", line 46, in <module>
from ._internal.exporter import ( # usort:skip. needs to be last to avoid circular import
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/exporter.py", line 44, in <module>
from torch.onnx._internal.fx import (
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/__init__.py", line 1, in <module>
from .patcher import ONNXTorchPatcher
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/patcher.py", line 11, in <module>
import transformers # type: ignore[import]
File "/usr/local/lib/python3.10/dist-packages/transformers/__init__.py", line 26, in <module>
from . import dependency_versions_check
File "/usr/local/lib/python3.10/dist-packages/transformers/dependency_versions_check.py", line 16, in <module>
from .utils.versions import require_version, require_version_core
File "/usr/local/lib/python3.10/dist-packages/transformers/utils/__init__.py", line 31, in <module>
from .generic import (
File "/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py", line 33, in <module>
import jax.numpy as jnp
File "/usr/local/lib/python3.10/dist-packages/jax/__init__.py", line 39, in <module>
from jax import config as _config_module
File "/usr/local/lib/python3.10/dist-packages/jax/config.py", line 15, in <module>
from jax._src.config import config as _deprecated_config # noqa: F401
File "/usr/local/lib/python3.10/dist-packages/jax/_src/config.py", line 28, in <module>
from jax._src import lib
File "/usr/local/lib/python3.10/dist-packages/jax/_src/lib/__init__.py", line 75, in <module>
jax_version=jax.version.__version__,
AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)
```
As I mentioned in https://github.com/p1atdev/LECO/issues/39, the jax library deprecated `jax.random.KeyArray` and removed it entirely in 0.4.24, so LECO has a broken dependency: the installed diffusers still references `jax.random.KeyArray` in `modeling_flax_utils.py`.
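A minimal diagnostic sketch, assuming only the standard library plus (optionally) jax. It reports the installed `jax`, `jaxlib`, and `diffusers` versions. On jax 0.4.24 or newer it aliases the removed name to `jax.Array`. The helper names (`parse_version`, `keyarray_removed`, `installed_version`, `patch_keyarray_if_missing`) are hypothetical. The alias is a stopgap commonly used in notebooks, not a fix from LECO or diffusers. It assumes diffusers only uses `jax.random.KeyArray` as a type annotation, as in the `init_weights` signature in the first traceback.

```python
import importlib.metadata
import importlib.util
from typing import Optional, Tuple


def parse_version(v: str) -> Tuple[int, ...]:
    # Keep the leading numeric part of the first three segments,
    # e.g. "0.4.26" -> (0, 4, 26). Assumption: "X.Y.Z"-style versions;
    # suffixes like "rc1" or ".dev..." are ignored.
    parts = []
    for piece in v.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def keyarray_removed(jax_version: str) -> bool:
    # Per this issue, jax.random.KeyArray no longer exists from 0.4.24 on.
    return parse_version(jax_version) >= (0, 4, 24)


def installed_version(dist: str) -> Optional[str]:
    # Installed distribution version, or None if it is not installed.
    try:
        return importlib.metadata.version(dist)
    except importlib.metadata.PackageNotFoundError:
        return None


def patch_keyarray_if_missing() -> bool:
    """Hypothetical stopgap: alias jax.random.KeyArray to jax.Array.

    Must run before `import diffusers` in the same process. Returns True
    only if the alias was actually added.
    """
    if importlib.util.find_spec("jax") is None:
        return False  # jax not installed, nothing to patch
    import jax

    if hasattr(jax.random, "KeyArray"):
        return False  # name still present (possibly deprecated)
    jax.random.KeyArray = jax.Array
    return True


if __name__ == "__main__":
    for dist in ("jax", "jaxlib", "diffusers"):
        print(dist, installed_version(dist))
    jax_v = installed_version("jax")
    if jax_v is not None and keyarray_removed(jax_v):
        print("KeyArray alias added:", patch_keyarray_if_missing())
```

Because `train_lora_xl.py` runs as its own Python process, the patch would have to run inside that process, for example near the top of the script before `from lora import ...`. A separate notebook cell would not affect it. The version printout does not explain the circular import seen with 0.4.23. It only shows which `jax`/`jaxlib` combination is actually installed.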