huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Sequential offloading bug with Stable Audio #8989

Open ylacombe opened 1 month ago

ylacombe commented 1 month ago

Describe the bug

Sequential CPU offloading fails for Stable Audio when run under pytest (the outputs contain NaN), but it seems to work when run outside the test suite. This is a problem because it means we can't properly test sequential offloading on Stable Audio.

Additional information: I've managed to pinpoint where the first NaN appears: at the autoencoder level, when passing through the decoder's conv1. It might have to do with the fact that we're using weight_norm, but at this point I really don't know, since it works outside of the testing file.
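
For anyone trying to reproduce the NaN hunt described above, here is a minimal sketch of one way to find the first submodule that emits a non-finite value, using plain PyTorch forward hooks. The helper name `find_first_nonfinite` and the toy model are hypothetical and not part of diffusers; the hook only inspects tensor, tuple, and list outputs, so modules that return other containers are skipped.

```python
import torch
from torch import nn


def find_first_nonfinite(model: nn.Module, *args, **kwargs):
    """Run one forward pass and return the name of the first submodule
    whose output contains NaN or Inf, or None if every output is finite.

    Hypothetical debugging helper, not a diffusers API.
    """
    first_bad = []
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            if first_bad:  # keep only the earliest offender
                return
            tensors = output if isinstance(output, (tuple, list)) else (output,)
            for t in tensors:
                if torch.is_tensor(t) and not torch.isfinite(t).all():
                    first_bad.append(name)
                    return
        return hook

    # Hooks fire in execution order, so leaf modules report before their parents.
    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(*args, **kwargs)
    finally:
        for handle in handles:
            handle.remove()
    return first_bad[0] if first_bad else None


# Toy demonstration: poison the last conv's bias so that it emits NaN.
toy = nn.Sequential(nn.Conv1d(2, 4, 3), nn.ReLU(), nn.Conv1d(4, 2, 3))
with torch.no_grad():
    toy[2].bias.fill_(float("nan"))

culprit = find_first_nonfinite(toy, torch.randn(1, 2, 16))
print(culprit)
```

In the setting of this issue, the same hooks would presumably be registered on `pipeline.vae` before calling the pipeline inside the test, rather than calling the model directly as the toy example does.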

Reproduction

Offloading with NaN

test_sequential_cpu_offload_forward_pass and test_sequential_offload_forward_pass_twice

Working offloading

import torch
from diffusers import StableAudioPipeline, StableAudioProjectionModel, StableAudioDiTModel, EDMDPMSolverMultistepScheduler, AutoencoderOobleck
from transformers import T5EncoderModel, T5Tokenizer

def get_dummy_components():
    torch.manual_seed(0)
    transformer = StableAudioDiTModel(
        sample_size=32,
        in_channels=6,
        num_layers=2,
        attention_head_dim=4,
        num_key_value_attention_heads=2,
        out_channels=6,
        cross_attention_dim=4,
        time_proj_dim=8,
        global_states_input_dim=48,
        cross_attention_input_dim=24
    )
    scheduler = EDMDPMSolverMultistepScheduler(
        solver_order=2,
        prediction_type="v_prediction",
        noise_preconditioning_strategy="atan",
        sigma_data=1.0,
        algorithm_type="sde-dpmsolver++",
        sigma_schedule="exponential",
        noise_sampling_strategy="brownian_tree",
    )
    torch.manual_seed(0)
    vae = AutoencoderOobleck(
        encoder_hidden_size=12,
        downsampling_ratios=[1, 2],
        decoder_channels=12,
        decoder_input_channels=6,
        audio_channels=2,
        channel_multiples=[2, 4],
        sampling_rate=32,
    )
    torch.manual_seed(0)
    t5_repo_id = "hf-internal-testing/tiny-random-T5ForConditionalGeneration"
    text_encoder = T5EncoderModel.from_pretrained(t5_repo_id)
    tokenizer = T5Tokenizer.from_pretrained(t5_repo_id, truncation=True, model_max_length=25)

    torch.manual_seed(0)
    projection_model = StableAudioProjectionModel(
        text_encoder_dim=text_encoder.config.d_model,
        conditioning_dim=24,
        min_value=0,
        max_value=256,
    )
    components = {
        "transformer": transformer,
        "scheduler": scheduler,
        "vae": vae,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "projection_model": projection_model,
    }
    return components

def get_dummy_inputs():
    generator = torch.manual_seed(0)
    inputs = {
        "prompt": "A hammer hitting a wooden surface",
        "generator": generator,
        "num_inference_steps": 2,
        "guidance_scale": 6.0,
    }
    return inputs

components = get_dummy_components()
pipeline = StableAudioPipeline(**components)
pipeline.enable_sequential_cpu_offload(device="cuda")

inputs = get_dummy_inputs()
pipeline(**inputs)
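
The script above discards the pipeline output, so on its own it does not show whether the result is free of NaN. Here is a small sketch of a check that could be applied to the returned audio. The helper `count_nonfinite` is hypothetical, and the `.audios` attribute mentioned in the comment is an assumption about the pipeline's output object.

```python
import torch


def count_nonfinite(audio) -> int:
    """Return how many elements of a tensor or array are NaN or Inf."""
    values = torch.as_tensor(audio)
    return int((~torch.isfinite(values)).sum())


# Assumed usage with the script above (not run here):
#   audio = pipeline(**inputs).audios
#   assert count_nonfinite(audio) == 0

# Self-contained demonstration on a toy tensor.
sample = torch.tensor([0.5, float("nan"), float("inf"), -1.0])
bad_values = count_nonfinite(sample)
print(bad_values)
```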

Logs

No response

System Info

- 🤗 Diffusers version: 0.30.0.dev0
- Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.29
- Running on Google Colab?: No
- Python version: 3.8.10
- PyTorch version (GPU?): 2.3.1+cu121 (True)
- Flax version (CPU?/GPU?/TPU?): 0.7.2 (cpu)
- Jax version: 0.4.13
- JaxLib version: 0.4.13
- Huggingface_hub version: 0.23.4
- Transformers version: 4.42.0.dev0
- Accelerate version: 0.31.0
- PEFT version: 0.11.1
- Bitsandbytes version: not installed
- Safetensors version: 0.4.3
- xFormers version: not installed
- Accelerator: NVIDIA A100-SXM4-80GB, 81920 MiB
  NVIDIA A100-SXM4-80GB, 81920 MiB
  NVIDIA A100-SXM4-80GB, 81920 MiB
  NVIDIA DGX Display, 4096 MiB
  NVIDIA A100-SXM4-80GB, 81920 MiB
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no

Who can help?

cc @yiyixuxu @sayakpaul

sayakpaul commented 1 month ago

Cc: @SunMarc