huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

INT_MAX error #2306

Closed: morrisalp closed this issue 1 year ago

morrisalp commented 1 year ago

Describe the bug

Running the Stable Diffusion 2 generation pipeline with fp16, attention slicing, and a batch size of 16 fails with the error "RuntimeError: upsample_nearest_nhwc only supports output tensors with less than INT_MAX elements". EDIT: Batch sizes up to 14 work; any batch size >= 15 triggers this error.

Reproduction

from diffusers import StableDiffusionPipeline
import torch

device = 'cuda'
model_id = "stabilityai/stable-diffusion-2"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to(device)
pipe.enable_attention_slicing()

g = torch.Generator(device=device).manual_seed(0)

pipe('a picture of a cat', generator=g, num_inference_steps=5, num_images_per_prompt=16)

Logs

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File <timed exec>:1

File .../lib/python3.8/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File .../lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py:546, in StableDiffusionPipeline.__call__(self, prompt, height, width, num_inference_steps, guidance_scale, negative_prompt, num_images_per_prompt, eta, generator, latents, output_type, return_dict, callback, callback_steps)
    543                 callback(i, t, latents)
    545 # 8. Post-processing
--> 546 image = self.decode_latents(latents)
    548 # 9. Run safety checker
    549 image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)

File .../lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py:341, in StableDiffusionPipeline.decode_latents(self, latents)
    339 def decode_latents(self, latents):
    340     latents = 1 / 0.18215 * latents
--> 341     image = self.vae.decode(latents).sample
    342     image = (image / 2 + 0.5).clamp(0, 1)
    343     # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16

File .../lib/python3.8/site-packages/diffusers/models/vae.py:605, in AutoencoderKL.decode(self, z, return_dict)
    603     decoded = torch.cat(decoded_slices)
    604 else:
--> 605     decoded = self._decode(z).sample
    607 if not return_dict:
    608     return (decoded,)

File .../lib/python3.8/site-packages/diffusers/models/vae.py:577, in AutoencoderKL._decode(self, z, return_dict)
    575 def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
    576     z = self.post_quant_conv(z)
--> 577     dec = self.decoder(z)
    579     if not return_dict:
    580         return (dec,)

File .../lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File .../lib/python3.8/site-packages/diffusers/models/vae.py:217, in Decoder.forward(self, z)
    215 # up
    216 for up_block in self.up_blocks:
--> 217     sample = up_block(sample)
    219 # post-process
    220 sample = self.conv_norm_out(sample)

File .../lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File .../lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py:1695, in UpDecoderBlock2D.forward(self, hidden_states)
   1693 if self.upsamplers is not None:
   1694     for upsampler in self.upsamplers:
-> 1695         hidden_states = upsampler(hidden_states)
   1697 return hidden_states

File .../lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File .../lib/python3.8/site-packages/diffusers/models/resnet.py:128, in Upsample2D.forward(self, hidden_states, output_size)
    125 # if `output_size` is passed we force the interpolation output
    126 # size and do not make use of `scale_factor=2`
    127 if output_size is None:
--> 128     hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
    129 else:
    130     hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

File .../lib/python3.8/site-packages/torch/nn/functional.py:3922, in interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
   3920     return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors)
   3921 if input.dim() == 4 and mode == "nearest":
-> 3922     return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
   3923 if input.dim() == 5 and mode == "nearest":
   3924     return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors)

RuntimeError: upsample_nearest_nhwc only supports output tensors with less than INT_MAX elements
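
The traceback ends in the VAE decoder's Upsample2D call, so a rough element count may explain why 14 images work and 15 do not. The sketch below is only back-of-the-envelope arithmetic under stated assumptions: that the pipeline generates 768x768 images, and that the largest nearest-neighbor upsample in the decoder outputs 256 channels at full resolution. Neither figure appears in this issue; both are assumptions about the stabilityai/stable-diffusion-2 configuration, and the helper names are made up for illustration.

```python
INT_MAX = 2**31 - 1  # 2,147,483,647, the limit named in the error message

# Assumed (not stated in the issue): per-image shape of the largest
# nearest-neighbor upsample output inside the VAE decoder.
ASSUMED_CHANNELS = 256
ASSUMED_HEIGHT = 768
ASSUMED_WIDTH = 768


def upsample_elements(batch_size, channels=ASSUMED_CHANNELS,
                      height=ASSUMED_HEIGHT, width=ASSUMED_WIDTH):
    """Number of elements in a (batch, channels, height, width) tensor."""
    return batch_size * channels * height * width


def largest_safe_batch(channels=ASSUMED_CHANNELS,
                       height=ASSUMED_HEIGHT, width=ASSUMED_WIDTH):
    """Largest batch whose output has strictly fewer than INT_MAX elements."""
    per_image = upsample_elements(1, channels, height, width)
    return (INT_MAX - 1) // per_image


for batch in (14, 15, 16):
    n = upsample_elements(batch)
    status = "ok" if n < INT_MAX else "exceeds INT_MAX"
    print(f"batch={batch}: {n:,} elements ({status})")

print("largest safe batch:", largest_safe_batch())
```

Under these assumptions, 14 images give 2,113,929,216 elements (below the limit) and 15 give 2,264,924,160 (above it), which is consistent with the threshold reported above.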

System Info

Python 3.8.13, dockerized JupyterLab 3.5.2, diffusers 0.11.0, CUDA 11.8, NVIDIA A5000 GPU

pip freeze output:

absl-py==1.2.0
accelerate==0.15.0
aiohttp==3.8.3
aiosignal==1.3.1
alabaster==0.7.12
anyio==3.6.1
apache-beam==2.44.0
apex==0.1
appdirs==1.4.4
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1660605382950/work
astunparse==1.6.3
async-timeout==4.0.2
attrs==22.1.0
audioread==3.0.0
Babel==2.10.3
backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1618230623929/work
beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1649463573192/work
bitsandbytes==0.35.4
bleach==5.0.1
BLEURT @ git+https://github.com/google-research/bleurt.git@cebe7e6f996b40910cfaa520a63db47807e3bf5c
blis @ file:///home/conda/feedstock_root/build_artifacts/cython-blis_1656314523915/work
Bottleneck==1.3.6
brotlipy @ file:///home/conda/feedstock_root/build_artifacts/brotlipy_1648854175163/work
cachetools==5.2.0
catalogue @ file:///home/conda/feedstock_root/build_artifacts/catalogue_1661366525041/work
certifi==2022.9.24
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1656782821535/work
chardet @ file:///home/conda/feedstock_root/build_artifacts/chardet_1656142044710/work
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1655906222726/work
click @ file:///home/conda/feedstock_root/build_artifacts/click_1651215152883/work
clip @ git+https://github.com/openai/CLIP.git@3702849800aa56e2223035bccd1c6ef91c704ca8
cloudpickle==2.2.0
codecov==2.1.12
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1655412516417/work
conda-package-handling @ file:///home/conda/feedstock_root/build_artifacts/conda-package-handling_1663583601093/work
contourpy==1.0.5
coverage==6.5.0
crcmod==1.7
cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography_1665535545125/work
cuda-python @ file:///rapids/cuda_python-11.7.0%2B0.g95a2041.dirty-cp38-cp38-linux_x86_64.whl
cudf @ file:///rapids/cudf-22.8.0a0%2B304.g6ca81bbc78.dirty-cp38-cp38-linux_x86_64.whl
cugraph @ file:///rapids/cugraph-22.8.0a0%2B132.g2daa31b6.dirty-cp38-cp38-linux_x86_64.whl
cuml @ file:///rapids/cuml-22.8.0a0%2B52.g73b8d00d0.dirty-cp38-cp38-linux_x86_64.whl
cupy-cuda118 @ file:///rapids/cupy_cuda118-11.0.0-cp38-cp38-linux_x86_64.whl
cycler==0.11.0
cymem @ file:///home/conda/feedstock_root/build_artifacts/cymem_1636053152744/work
Cython==0.29.32
dask @ file:///rapids/dask-2022.7.1-py3-none-any.whl
dask-cuda @ file:///rapids/dask_cuda-22.8.0a0%2B36.g9860cad-py3-none-any.whl
dask-cudf @ file:///rapids/dask_cudf-22.8.0a0%2B304.g6ca81bbc78.dirty-py3-none-any.whl
dataclasses @ file:///home/conda/feedstock_root/build_artifacts/dataclasses_1628958434797/work
datasets==2.8.0
debugpy==1.6.3
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
defusedxml==0.7.1
diffusers==0.11.0
dill==0.3.1.1
distributed @ file:///rapids/distributed-2022.7.1-py3-none-any.whl
dlib==19.24.0
docopt==0.6.2
docutils==0.17.1
en-core-web-trf @ https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.2.0/en_core_web_trf-3.2.0-py3-none-any.whl
entrypoints==0.3
et-xmlfile==1.1.0
evaluate==0.4.0
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1665301981797/work
expecttest==0.1.3
facenet-pytorch==2.5.2
fastavro==1.7.0
fasteners==0.18
fastjsonschema==2.16.2
fastrlock==0.8
filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1660129891014/work
flake8==3.7.9
Flask==2.2.2
flatbuffers==22.12.6
fonttools==4.37.4
frozenlist==1.3.3
fsspec==2022.8.2
fst-pso==1.8.1
ftfy==6.1.1
functorch==0.3.0a0
future==0.18.2
FuzzyTM==2.0.5
fuzzywuzzy==0.18.0
gast==0.4.0
gensim==4.3.0
gitdb==4.0.10
GitPython==3.1.30
glob2==0.7
google-auth==2.12.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
graphsurgeon @ file:///workspace/TensorRT-8.5.0.12/graphsurgeon/graphsurgeon-0.4.6-py2.py3-none-any.whl
grpcio==1.49.1
h5py==3.7.0
hdfs==2.7.0
HeapDict==1.0.1
httplib2==0.20.4
huggingface-hub==0.11.0
hypothesis==4.50.8
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1642433548627/work
imageio==2.25.0
imagesize==1.4.1
importlib-metadata==5.0.0
importlib-resources==5.10.0
iniconfig==1.1.1
iopath==0.1.10
ipykernel==6.16.0
ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1662481517711/work
ipython-genutils==0.2.0
ipywidgets==8.0.2
itsdangerous==2.1.2
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1659959867326/work
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1654302431367/work
joblib==1.2.0
json5==0.9.10
jsonschema==4.16.0
jupyter-core==4.11.1
jupyter-server==1.21.0
jupyter-tensorboard @ git+https://github.com/cliffwoolley/jupyter_tensorboard.git@ffa7e26138b82549453306e06b535a9ac36db17a
jupyter_client==7.4.2
jupyterlab==2.3.2
jupyterlab-pygments==0.2.2
jupyterlab-server==1.2.0
jupyterlab-widgets==3.0.3
jupytext==1.14.1
keras==2.10.0
Keras-Preprocessing==1.1.2
kiwisolver==1.4.4
langcodes @ file:///home/conda/feedstock_root/build_artifacts/langcodes_1636741340529/work
leven==1.0.4
Levenshtein==0.20.9
libarchive-c @ file:///home/conda/feedstock_root/build_artifacts/python-libarchive-c_1649436017468/work
libclang==14.0.6
librosa==0.9.2
lightning-utilities==0.4.2
llvmlite==0.39.1
lmdb==1.3.0
locket==1.0.0
Markdown==3.4.1
markdown-it-py==2.1.0
MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1648737563195/work
matplotlib==3.6.2
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work
mccabe==0.6.1
mdit-py-plugins==0.3.1
mdurl==0.1.2
miniful==0.0.6
mistune==2.0.4
mock @ file:///home/conda/feedstock_root/build_artifacts/mock_1648992799371/work
msgpack==1.0.4
multidict==6.0.3
multiprocess==0.70.14
murmurhash @ file:///home/conda/feedstock_root/build_artifacts/murmurhash_1636019583024/work
mwparserfromhell==0.6.4
nbclassic==0.4.5
nbclient==0.7.0
nbconvert==7.2.1
nbformat==5.7.0
nest-asyncio==1.5.6
networkx==2.6.3
nltk==3.7
nose==1.3.7
notebook==6.4.10
notebook-shim==0.1.0
numba==0.56.2
numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1643958805350/work
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-dali-cuda110==1.18.0
nvidia-pyindex==1.0.9
nvtx==0.2.5
oauthlib==3.2.1
objsize==0.6.1
onnx @ file:///opt/pytorch/pytorch/third_party/onnx
openai==0.25.0
openpyxl==3.0.10
opt-einsum==3.3.0
orjson==3.8.5
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1637239678211/work
pandas==1.4.4
pandas-stubs==1.5.2.221213
pandocfilters==1.5.0
parrot @ git+https://github.com/PrithivirajDamodaran/Parrot_Paraphraser.git@720a87a1ee557d8ed8d9a021adbdd1dd5616c5f9
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
partd==1.3.0
pathy @ file:///home/conda/feedstock_root/build_artifacts/pathy_1656568808184/work
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1602535608087/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
Pillow @ file:///tmp/pillow-simd
pkginfo @ file:///home/conda/feedstock_root/build_artifacts/pkginfo_1654782790443/work
pkgutil_resolve_name==1.3.10
pluggy==1.0.0
polygraphy==0.42.1
pooch==1.6.0
portalocker==2.5.1
preshed @ file:///home/conda/feedstock_root/build_artifacts/preshed_1636077712344/work
prettytable==3.4.1
prometheus-client==0.15.0
prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1662384672173/work
proto-plus==1.22.2
protobuf==3.19.6
psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1662356143277/work
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
py==1.11.0
pyarrow @ file:///rapids/pyarrow-8.0.0-cp38-cp38-linux_x86_64.whl
pyasn1==0.4.8
pyasn1-modules==0.2.8
pybind11==2.10.0
pycocotools @ git+https://github.com/nvidia/cocoapi.git@142b17a358fdb5a31f9d5153d7a9f3f1cd385178#subdirectory=PythonAPI
pycodestyle==2.5.0
pycosat @ file:///home/conda/feedstock_root/build_artifacts/pycosat_1649384811940/work
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
pydantic @ file:///home/conda/feedstock_root/build_artifacts/pydantic_1636021149719/work
pydot==1.4.2
pyflakes==2.1.1
pyFUME==0.2.25
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1660666458521/work
pylibcugraph @ file:///rapids/pylibcugraph-22.8.0a0%2B132.g2daa31b6.dirty-cp38-cp38-linux_x86_64.whl
pymongo==3.13.0
pynvml==11.4.1
pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1643496850550/work
pyparsing @ file:///home/conda/feedstock_root/build_artifacts/pyparsing_1652235407899/work
pyrsistent==0.18.1
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
pytest==6.2.5
pytest-cov==4.0.0
pytest-pythonpath==0.7.4
python-dateutil==2.8.2
python-hostlist==1.22
python-Levenshtein==0.20.9
python-nvd3==0.15.0
python-slugify==6.1.2
pytorch-lightning==1.8.5.post0
pytorch-quantization==2.1.2
pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1664798238822/work
PyWavelets==1.4.1
PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1648757091578/work
pyzmq==24.0.1
raft @ file:///rapids/raft-22.8.0a0%2B70.g9070c30.dirty-cp38-cp38-linux_x86_64.whl
rapidfuzz==2.13.7
regex==2022.9.13
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1656534056640/work
requests-oauthlib==1.3.1
resampy==0.4.2
responses==0.18.0
revtok @ git+git://github.com/jekbradbury/revtok.git@f1998b72a941d1e5f9578a66dc1c20b01913caab
rmm @ file:///rapids/rmm-22.8.0a0%2B62.gf6bf047.dirty-cp38-cp38-linux_x86_64.whl
rsa==4.9
ruamel-yaml-conda @ file:///home/conda/feedstock_root/build_artifacts/ruamel_yaml_1653464386701/work
sacremoses==0.0.53
safetensors==0.2.8
scikit-image==0.19.3
scikit-learn @ file:///rapids/scikit_learn-0.24.2-cp38-cp38-manylinux2010_x86_64.whl
scipy==1.10.0
seaborn==0.12.1
Send2Trash==1.8.0
sentence-transformers==2.2.2
sentencepiece==0.1.97
setfit==0.5.0
shellingham @ file:///home/conda/feedstock_root/build_artifacts/shellingham_1659638615822/work
simpful==2.9.0
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
smart-open @ file:///home/conda/feedstock_root/build_artifacts/smart_open_1630238320325/work
smmap==5.0.0
sniffio==1.3.0
snowballstemmer==2.2.0
sortedcontainers==2.4.0
soundfile==0.11.0
soupsieve @ file:///home/conda/feedstock_root/build_artifacts/soupsieve_1658207591808/work
spacy @ file:///home/conda/feedstock_root/build_artifacts/spacy_1644657943105/work
spacy-alignments==0.9.0
spacy-legacy @ file:///home/conda/feedstock_root/build_artifacts/spacy-legacy_1660748275723/work
spacy-loggers @ file:///home/conda/feedstock_root/build_artifacts/spacy-loggers_1661365735520/work
spacy-transformers==1.1.7
Sphinx==5.2.3
sphinx-glpi-theme==0.3
sphinx-rtd-theme==1.0.0
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
srsly @ file:///home/conda/feedstock_root/build_artifacts/srsly_1638879568141/work
stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1664126450622/work
tabulate==0.9.0
tblib==1.7.0
tensorboard==2.10.1
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorboardX==2.5.1
tensorflow==2.10.0
tensorflow-estimator==2.10.0
tensorflow-io-gcs-filesystem==0.29.0
tensorrt @ file:///workspace/TensorRT-8.5.0.12/python/tensorrt-8.5.0.12-cp38-none-linux_x86_64.whl
termcolor==2.1.1
terminado==0.16.0
text-unidecode==1.3
tf-slim==1.1.0
thinc @ file:///home/conda/feedstock_root/build_artifacts/thinc_1638980259098/work
threadpoolctl==3.1.0
tifffile==2023.1.23.1
tinycss2==1.1.1
tokenizers==0.12.1
toml @ file:///home/conda/feedstock_root/build_artifacts/toml_1604308577558/work
tomli==2.0.1
toolz @ file:///home/conda/feedstock_root/build_artifacts/toolz_1657485559105/work
torch==1.13.1
torch-tensorrt @ file:///opt/pytorch/torch_tensorrt/py/dist/torch_tensorrt-1.3.0a0-cp38-cp38-linux_x86_64.whl
torchinfo==1.7.1
torchmetrics==0.11.0
torchtext==0.11.0a0
torchvision @ file:///opt/pytorch/vision
tornado==6.2
tqdm==4.64.1
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1663005918942/work
transformer-engine @ file:///tmp/te_wheel/transformer_engine-0.1.0-cp38-cp38-linux_x86_64.whl
transformers==4.26.0
treelite @ file:///rapids/treelite-2.4.0-py3-none-manylinux2014_x86_64.whl
treelite-runtime @ file:///rapids/treelite_runtime-2.4.0-py3-none-manylinux2014_x86_64.whl
typer @ file:///home/conda/feedstock_root/build_artifacts/typer_1657029164904/work
types-pytz==2022.7.0.0
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1665144421445/work
ucx-py @ file:///rapids/ucx_py-0.27.0a0%2B29.ge9e81f8-cp38-cp38-linux_x86_64.whl
uff @ file:///workspace/TensorRT-8.5.0.12/uff/uff-0.6.9-py2.py3-none-any.whl
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1658789158161/work
wasabi @ file:///home/conda/feedstock_root/build_artifacts/wasabi_1658931821849/work
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1600965781394/work
webencodings==0.5.1
websocket-client==1.4.1
Werkzeug==2.2.2
widgetsnbextension==4.0.3
wrapt==1.14.1
xgboost @ file:///rapids/xgboost-1.6.1-cp38-cp38-linux_x86_64.whl
xxhash==3.2.0
yarl==1.8.2
zict==2.2.0
zipp==3.9.0
zstandard==0.19.0

pcuenca commented 1 year ago

This sounds like a PyTorch issue that was recently resolved (in the codebase). Using pipe.enable_vae_slicing() might work around it; otherwise we could try to apply the workaround I described in that issue.

morrisalp commented 1 year ago

I see pipe.enable_vae_slicing() avoids the issue.
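
For reference, here is a minimal sketch of the reproduction script with VAE slicing switched on. It is the original script plus one extra call, wrapped in a function (the name generate_with_vae_slicing and its parameters are illustrative, not part of diffusers) with the heavy imports kept inside so that defining it does not need a GPU or a model download. Running it still assumes a CUDA device with enough memory.

```python
def generate_with_vae_slicing(prompt="a picture of a cat", num_images=16,
                              steps=5, seed=0):
    """Reproduction script from this issue, with VAE slicing enabled."""
    # Imported lazily: these need torch, diffusers, and (when called) a GPU.
    import torch
    from diffusers import StableDiffusionPipeline

    device = "cuda"
    model_id = "stabilityai/stable-diffusion-2"
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    )
    pipe = pipe.to(device)
    pipe.enable_attention_slicing()
    # Decode the latents in slices rather than as one large batch, so the
    # decoder's upsample never sees the whole batch in a single tensor.
    pipe.enable_vae_slicing()

    g = torch.Generator(device=device).manual_seed(seed)
    result = pipe(
        prompt,
        generator=g,
        num_inference_steps=steps,
        num_images_per_prompt=num_images,
    )
    return result.images  # list of PIL images


# Usage (requires a CUDA GPU and downloads the model weights):
# images = generate_with_vae_slicing(num_images=16)
```

The sliced path is the torch.cat(decoded_slices) branch visible in AutoencoderKL.decode in the traceback above.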

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.