pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

torch.layer_norm causes RESOURCE_EXHAUSTED #8395

Open radna0 opened 3 days ago

radna0 commented 3 days ago

πŸ› Bug

I'm trying to do inference with the CogVideoX1.5-5B model, and everything runs fine until the layer norm step.

To Reproduce

Steps to reproduce the behavior (modified pipeline attached: xka_pipeline_cogvideox.txt):

  1. clone the CogVideoX repo

  2. Run pip install -r requirements.txt (PyTorch/XLA will then need to be reinstalled; see Environment below)

  3. Install the following (diffusers must be installed from source):

    diffusers>=0.32.0dev (or from source)
    transformers>=4.46.2
    accelerate>=1.1.1
    imageio-ffmpeg>=0.5.1
    pip install --upgrade transformers accelerate diffusers imageio-ffmpeg
    pip install git+https://github.com/huggingface/diffusers
  4. See inference/cli_demo.py, go to the source of CogVideoXPipeline, and replace its contents with xka_pipeline_cogvideox.txt

  5. Comment out the line pipe.enable_sequential_cpu_offload()

  6. Finally run inference/cli_demo.py
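With sequential CPU offload commented out (step 5), the full transformer has to live in device memory. A rough back-of-the-envelope sketch of the parameter memory alone, assuming ~5B parameters as the model name suggests (actual counts and dtypes may differ):

```python
# Back-of-the-envelope parameter memory for a ~5B-parameter model.
# Assumption: parameter count taken from the "5B" in the model name.
PARAMS = 5e9

def param_gib(num_params: float, bytes_per_param: int) -> float:
    """Parameter memory in GiB for a given element width."""
    return num_params * bytes_per_param / 1024**3

print(f"fp32: {param_gib(PARAMS, 4):.1f} GiB")  # ~18.6 GiB
print(f"bf16: {param_gib(PARAMS, 2):.1f} GiB")  # ~9.3 GiB
```

Either way, the weights alone consume a large fraction of a single TPU core's HBM before any activations are allocated, which is why removing the offload matters.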

The error:

kojoe@t1v-n-3ad02607-w-0:~/CogVideo$ python3.10 inference/cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX1.5-5b --generate_type "t2v"
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
WARNING:bitsandbytes.cextension:The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.
Loading checkpoint shards: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 4/4 [00:01<00:00,  2.06it/s]
Loading pipeline components...: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 5/5 [00:03<00:00,  1.50it/s]
  0%|                                                                          | 0/50 [00:00<?, ?it/s]tcmalloc: large alloc 1092296704 bytes == 0x2f3090000 @  0x7f0ccb8d6680 0x7f0ccb8f7824 0x7f0b18512279 0x7f0b112100aa 0x7f0b1120fffe 0x7f0b111bd51c 0x7f0b111bd3e4 0x7f0b1055dfeb 0x7f0b10568045 0x7f0b105baca8 0x7f0b107debdd 0x7f0b107e1537 0x7f0b17fb1c73 0x7f0b17fb7e36 0x7f0b17fc09a5 0x7f0b18156963 0x7f0ccb6a8609 0x7f0ccb7e2353
  0%|                                                                          | 0/50 [08:56<?, ?it/s]
Traceback (most recent call last):
  File "/home/kojoe/CogVideo/inference/cli_demo.py", line 258, in <module>
    generate_video(
  File "/home/kojoe/CogVideo/inference/cli_demo.py", line 148, in generate_video
    video_generate = pipe(
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/kojoe/.local/lib/python3.10/site-packages/diffusers/pipelines/cogvideo/pipeline_cogvideox.py", line 736, in __call__
    noise_pred = self.transformer(
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/experimental/spmd_fully_sharded_data_parallel.py", line 160, in forward
    output = self.module(*args, **kwargs)
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
    return inner()
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1790, in inner
    result = forward_call(*args, **kwargs)
  File "/home/kojoe/.local/lib/python3.10/site-packages/diffusers/models/transformers/cogvideox_transformer_3d.py", line 503, in forward
    hidden_states, encoder_hidden_states = block(
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/kojoe/.local/lib/python3.10/site-packages/diffusers/models/transformers/cogvideox_transformer_3d.py", line 132, in forward
    attn_hidden_states, attn_encoder_hidden_states = self.attn1(
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/kojoe/.local/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 530, in forward
    return self.processor(
  File "/home/kojoe/.local/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 2285, in __call__
    query = attn.norm_q(query)
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch/nn/modules/normalization.py", line 217, in forward
    return F.layer_norm(
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch/nn/functional.py", line 2910, in layer_norm
    return torch.layer_norm(
RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: Error loading program: Attempting to reserve 5.67G at the bottom of memory. That was not possible. There are 111.70M free, 0B reserved, and 111.70M reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
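The arithmetic in the message itself shows how far over budget the program is: the runtime wants a contiguous 5.67 GiB reservation while only 111.70 MiB is free:

```python
# Shortfall implied by the RESOURCE_EXHAUSTED message above.
GIB = 1024**3
MIB = 1024**2

needed = 5.67 * GIB   # reservation requested in the error message
free = 111.70 * MIB   # free device memory reported
print(f"requested ~{needed / free:.0f}x the free memory")  # ~52x
```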

Expected behavior

The model should have run normally; the error appears to be specific to the layer norm step.
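For context, F.layer_norm is itself a cheap normalization over the last dimension; the failure is about materializing the large input on device, not about the math. A pure-Python sketch of what torch.layer_norm computes per row (no weight/bias; eps matching PyTorch's default):

```python
import math

def layer_norm_1d(x, eps=1e-5):
    """Normalize a 1-D list: (x - mean) / sqrt(var + eps)."""
    n = len(x)
    mean = sum(x) / n
    # Biased (population) variance, as torch.layer_norm uses.
    var = sum((v - mean) ** 2 for v in x) / n
    return [(v - mean) / math.sqrt(var + eps) for v in x]

print([round(v, 3) for v in layer_norm_1d([1.0, 2.0, 3.0, 4.0])])
# -> [-1.342, -0.447, 0.447, 1.342]
```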

A screenshot of the memory info at the time of the error is attached.

Environment

To install:

pip3 install torch==2.6.0.dev20241105+cpu torchvision torchaudio==2.5.0.dev20241105+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241105-cp310-cp310-linux_x86_64.whl
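A quick sanity check that the nightly wheels above are the ones actually importable in the environment (package names only; versions will differ by nightly date):

```python
import importlib.util

def installed(pkgs):
    """Map each top-level package name to whether it is importable here."""
    return {p: importlib.util.find_spec(p) is not None for p in pkgs}

print(installed(["torch", "torchvision", "torchaudio", "torch_xla"]))
```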

Additional context