google / prompt-to-prompt

Apache License 2.0
2.98k stars 279 forks source link

Bug with 'from diffusers import StableDiffusionPipeline' #40

Open studying910 opened 1 year ago

studying910 commented 1 year ago

I have installed the required diffusers and transformers packages, but the following error occurs when importing:


TypeError Traceback (most recent call last)

in 1 from typing import Optional, Union, Tuple, List, Callable, Dict 2 import torch ----> 3 from diffusers import StableDiffusionPipeline 4 import torch.nn.functional as nnf 5 import numpy as np ~/anaconda3/lib/python3.8/site-packages/diffusers/__init__.py in 24 ) 25 from .pipeline_utils import DiffusionPipeline ---> 26 from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline 27 from .schedulers import ( 28 DDIMScheduler, ~/anaconda3/lib/python3.8/site-packages/diffusers/pipelines/__init__.py in 9 10 if is_transformers_available(): ---> 11 from .latent_diffusion import LDMTextToImagePipeline 12 from .stable_diffusion import ( 13 StableDiffusionImg2ImgPipeline, ~/anaconda3/lib/python3.8/site-packages/diffusers/pipelines/latent_diffusion/__init__.py in 4 5 if is_transformers_available(): ----> 6 from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline ~/anaconda3/lib/python3.8/site-packages/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py in 7 import torch.utils.checkpoint 8 ----> 9 from transformers.activations import ACT2FN 10 from transformers.configuration_utils import PretrainedConfig 11 from transformers.modeling_outputs import BaseModelOutput ~/anaconda3/lib/python3.8/site-packages/transformers/__init__.py in 28 29 # Check the dependencies satisfy the minimal versions required. ---> 30 from . 
import dependency_versions_check 31 from .utils import ( 32 OptionalDependencyNotAvailable, ~/anaconda3/lib/python3.8/site-packages/transformers/dependency_versions_check.py in 15 16 from .dependency_versions_table import deps ---> 17 from .utils.versions import require_version, require_version_core 18 19 ~/anaconda3/lib/python3.8/site-packages/transformers/utils/__init__.py in 32 replace_return_docstrings, 33 ) ---> 34 from .generic import ( 35 ContextManagers, 36 ExplicitEnum, ~/anaconda3/lib/python3.8/site-packages/transformers/utils/generic.py in 31 32 if is_tf_available(): ---> 33 import tensorflow as tf 34 35 if is_flax_available(): ~/anaconda3/lib/python3.8/site-packages/tensorflow/__init__.py in 53 from ._api.v2 import autograph 54 from ._api.v2 import bitwise ---> 55 from ._api.v2 import compat 56 from ._api.v2 import config 57 from ._api.v2 import data ~/anaconda3/lib/python3.8/site-packages/tensorflow/_api/v2/compat/__init__.py in 37 import sys as _sys 38 ---> 39 from . import v1 40 from . import v2 41 from tensorflow.python.compat.compat import forward_compatibility_horizon ~/anaconda3/lib/python3.8/site-packages/tensorflow/_api/v2/compat/v1/__init__.py in 32 from . import autograph 33 from . import bitwise ---> 34 from . import compat 35 from . import config 36 from . import data ~/anaconda3/lib/python3.8/site-packages/tensorflow/_api/v2/compat/v1/compat/__init__.py in 37 import sys as _sys 38 ---> 39 from . import v1 40 from . 
import v2 41 from tensorflow.python.compat.compat import forward_compatibility_horizon ~/anaconda3/lib/python3.8/site-packages/tensorflow/_api/v2/compat/v1/compat/v1/__init__.py in 49 from tensorflow._api.v2.compat.v1 import layers 50 from tensorflow._api.v2.compat.v1 import linalg ---> 51 from tensorflow._api.v2.compat.v1 import lite 52 from tensorflow._api.v2.compat.v1 import logging 53 from tensorflow._api.v2.compat.v1 import lookup ~/anaconda3/lib/python3.8/site-packages/tensorflow/_api/v2/compat/v1/lite/__init__.py in 9 10 from . import constants ---> 11 from . import experimental 12 from tensorflow.lite.python.lite import Interpreter 13 from tensorflow.lite.python.lite import OpHint ~/anaconda3/lib/python3.8/site-packages/tensorflow/_api/v2/compat/v1/lite/experimental/__init__.py in 8 import sys as _sys 9 ---> 10 from . import authoring 11 from tensorflow.lite.python.analyzer import ModelAnalyzer as Analyzer 12 from tensorflow.lite.python.lite import OpResolverType ~/anaconda3/lib/python3.8/site-packages/tensorflow/_api/v2/compat/v1/lite/experimental/authoring/__init__.py in 8 import sys as _sys 9 ---> 10 from tensorflow.lite.python.authoring.authoring import compatible 11 12 del _print_function ~/anaconda3/lib/python3.8/site-packages/tensorflow/lite/python/authoring/authoring.py in 41 42 # pylint: disable=g-import-not-at-top ---> 43 from tensorflow.lite.python import convert 44 from tensorflow.lite.python import lite 45 from tensorflow.lite.python.metrics_wrapper import converter_error_data_pb2 ~/anaconda3/lib/python3.8/site-packages/tensorflow/lite/python/convert.py in 31 32 from tensorflow.lite.python import lite_constants ---> 33 from tensorflow.lite.python import util 34 from tensorflow.lite.python import wrap_toco 35 from tensorflow.lite.python.convert_phase import Component ~/anaconda3/lib/python3.8/site-packages/tensorflow/lite/python/util.py in 53 # pylint: disable=unused-import 54 try: ---> 55 from jax import xla_computation as _xla_computation 56 
except ImportError: 57 _xla_computation = None ~/anaconda3/lib/python3.8/site-packages/jax/__init__.py in 90 # These submodules are separate because they are in an import cycle with 91 # jax and rely on the names imported above. ---> 92 from . import image 93 from . import lax 94 from . import nn ~/anaconda3/lib/python3.8/site-packages/jax/image/__init__.py in 16 17 # flake8: noqa: F401 ---> 18 from jax._src.image.scale import ( 19 resize, 20 ResizeMethod, ~/anaconda3/lib/python3.8/site-packages/jax/_src/image/scale.py in 18 19 from jax import jit ---> 20 from jax import lax 21 from jax import numpy as jnp 22 import numpy as np ~/anaconda3/lib/python3.8/site-packages/jax/lax/__init__.py in 322 while_p, 323 ) --> 324 from jax._src.lax.fft import ( 325 fft, 326 fft_p, ~/anaconda3/lib/python3.8/site-packages/jax/_src/lax/fft.py in 85 86 @partial(jit, static_argnums=1) ---> 87 def _rfft_transpose(t, fft_lengths): 88 # The transpose of RFFT can't be expressed only in terms of irfft. Instead of 89 # manually building up larger twiddle matrices (which would increase the ~/anaconda3/lib/python3.8/site-packages/jax/api.py in jit(fun, static_argnums, device, backend, donate_argnums) 179 """ 180 if FLAGS.experimental_cpp_jit and config.omnistaging_enabled: --> 181 return _cpp_jit(fun, static_argnums, device, backend, donate_argnums) 182 else: 183 return _python_jit(fun, static_argnums, device, backend, donate_argnums) ~/anaconda3/lib/python3.8/site-packages/jax/api.py in _cpp_jit(fun, static_argnums, device, backend, donate_argnums) 365 366 static_argnums_ = (0,) + tuple(i + 1 for i in static_argnums) --> 367 cpp_jitted_f = jax_jit.jit(fun, cache_miss, get_device_info, 368 get_jax_enable_x64, get_jax_disable_jit_flag, 369 static_argnums_) TypeError: jit(): incompatible function arguments. The following argument types are supported: 1. 
(fun: function, cache_miss: function, get_device: function, static_argnums: List[int], static_argnames: List[str] = [], donate_argnums: List[int] = [], cache: jaxlib.xla_extension.CompiledFunctionCache = None) -> object Invoked with: , .cache_miss at 0x7f44d1e18f70>, .get_device_info at 0x7f44d1e1e040>, .get_jax_enable_x64 at 0x7f44d1e1e0d0>, .get_jax_disable_jit_flag at 0x7f44d1e1e160>, (0, 2) -----------------------------------------------------------------------------------------------------

What should I do to fix it?