xdit-project / xDiT

xDiT: A Scalable Inference Engine for Diffusion Transformers (DiTs) on multi-GPU Clusters
Apache License 2.0

RTX 4090, flux model, out of memory; approach is not compatible with quantization #218

Closed csdY123 closed 3 weeks ago

csdY123 commented 3 weeks ago

The `from_pretrained` I am using (modified to quantize the transformer and T5 encoder with optimum-quanto):

```python
# Imports this excerpt relies on (EngineConfig is already imported in
# xfuser's pipeline_flux.py):
import os
import time
from typing import Optional, Union

import torch
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    FluxPipeline,
    FluxTransformer2DModel,
)
from optimum.quanto import freeze, qfloat8, quantize
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast


@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
    engine_config: EngineConfig,
    **kwargs,
):
    dtype = torch.bfloat16
    bfl_repo = "black-forest-labs/FLUX.1-dev"
    flux_snapshot = "models--black-forest-labs--FLUX.1-dev/snapshots/01aa605f2c300568dd6515476f04565a954fcb59"
    clip_snapshot = "models--openai--clip-vit-large-patch14/snapshots/32bd64288804d66eefd0ccbe215aa642df71cc41"

    print("Loading model...")
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_snapshot, subfolder="scheduler")
    text_encoder = CLIPTextModel.from_pretrained(clip_snapshot, torch_dtype=dtype)
    tokenizer = CLIPTokenizer.from_pretrained(clip_snapshot, torch_dtype=dtype)
    text_encoder_2 = T5EncoderModel.from_pretrained(flux_snapshot, subfolder="text_encoder_2", torch_dtype=dtype)
    tokenizer_2 = T5TokenizerFast.from_pretrained(flux_snapshot, subfolder="tokenizer_2", torch_dtype=dtype)
    vae = AutoencoderKL.from_pretrained(flux_snapshot, subfolder="vae", torch_dtype=dtype)
    transformer = FluxTransformer2DModel.from_pretrained(flux_snapshot, subfolder="transformer", torch_dtype=dtype)

    print("Loading in the model")
    # Experimental: try this to load in 4-bit for <16 GB cards.
    #
    # from optimum.quanto import qint4
    # quantize(transformer, weights=qint4, exclude=["proj_out", "x_embedder", "norm_out", "context_embedder"])
    # freeze(transformer)
    quantize(transformer, weights=qfloat8)
    freeze(transformer)

    quantize(text_encoder_2, weights=qfloat8)
    freeze(text_encoder_2)

    print("Loading in the model")

    start_time = time.time()

    pipe = FluxPipeline(
        scheduler=scheduler,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=None,
        tokenizer_2=tokenizer_2,
        vae=vae,
        transformer=None,
    )
    # Attach the quantized modules after the pipeline is constructed.
    pipe.text_encoder_2 = text_encoder_2
    pipe.transformer = transformer
    # pipe.enable_model_cpu_offload()

    # pipeline = FluxPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
    return cls(pipe, engine_config)
```

Reported error:

```
INFO 08-23 16:06:18 [base_model.py:80] [RANK 0] Wrapping transformer_blocks.0.attn in model class FluxTransformer2DModel with xFuserAttentionWrapper
INFO 08-23 16:06:18 [base_pipeline.py:204] Transformer backbone found, paralleling transformer...
INFO 08-23 16:06:18 [base_model.py:80] [RANK 1] Wrapping transformer_blocks.0.attn in model class FluxTransformer2DModel with xFuserAttentionWrapper
rank0: Traceback (most recent call last):
rank0:   File "/ML-A800/team/mm/chensenda/xDiT/./examples/flux_example.py", line 67, in <module>
rank0:   File "/ML-A800/team/mm/chensenda/xDiT/./examples/flux_example.py", line 19, in main
rank0:     pipe = xFuserFluxPipeline.from_pretrained(
rank0:   File "/ML-A800/team/mm/chensenda/anaconda3/envs/XDIT/lib/python3.10/site-packages/xfuser-0.2-py3.10.egg/xfuser/model_executor/pipelines/pipeline_flux.py", line 120, in from_pretrained
rank0:   File "/ML-A800/team/mm/chensenda/anaconda3/envs/XDIT/lib/python3.10/site-packages/xfuser-0.2-py3.10.egg/xfuser/model_executor/pipelines/base_pipeline.py", line 58, in __init__
rank0:   File "/ML-A800/team/mm/chensenda/anaconda3/envs/XDIT/lib/python3.10/site-packages/xfuser-0.2-py3.10.egg/xfuser/model_executor/pipelines/base_pipeline.py", line 206, in _convert_transformer_backbone
rank0:   File "/ML-A800/team/mm/chensenda/anaconda3/envs/XDIT/lib/python3.10/site-packages/xfuser-0.2-py3.10.egg/xfuser/model_executor/models/transformers/transformer_flux.py", line 29, in __init__
rank0:   File "/ML-A800/team/mm/chensenda/anaconda3/envs/XDIT/lib/python3.10/site-packages/xfuser-0.2-py3.10.egg/xfuser/model_executor/models/transformers/base_transformer.py", line 29, in __init__
rank0:   File "/ML-A800/team/mm/chensenda/anaconda3/envs/XDIT/lib/python3.10/site-packages/xfuser-0.2-py3.10.egg/xfuser/model_executor/models/transformers/base_transformer.py", line 49, in _convert_transformer_for_parallel
rank0:   File "/ML-A800/team/mm/chensenda/anaconda3/envs/XDIT/lib/python3.10/site-packages/xfuser-0.2-py3.10.egg/xfuser/model_executor/models/base_model.py", line 97, in _wrap_layers
rank0:   File "/ML-A800/team/mm/chensenda/anaconda3/envs/XDIT/lib/python3.10/site-packages/xfuser-0.2-py3.10.egg/xfuser/model_executor/layers/attention_processor.py", line 106, in __init__
rank0:   File "/ML-A800/team/mm/chensenda/anaconda3/envs/XDIT/lib/python3.10/site-packages/xfuser-0.2-py3.10.egg/xfuser/model_executor/layers/attention_processor.py", line 58, in __init__
rank0:   File "/ML-A800/team/mm/chensenda/anaconda3/envs/XDIT/lib/python3.10/site-packages/optimum/quanto/tensor/qtensor.py", line 93, in __torch_function__
rank0:     return func(*args, **kwargs)
rank0:   File "/ML-A800/team/mm/chensenda/anaconda3/envs/XDIT/lib/python3.10/site-packages/optimum/quanto/tensor/qbytes.py", line 130, in __torch_dispatch__
rank0:     return qdispatch(*args, **kwargs)
rank0:   File "/ML-A800/team/mm/chensenda/anaconda3/envs/XDIT/lib/python3.10/site-packages/optimum/quanto/tensor/qbytesops.py", line 121, in copy
rank0:     assert dest.qtype == src.qtype
rank0: AttributeError: 'Tensor' object has no attribute 'qtype'. Did you mean: 'dtype'?
```

rank 1 fails with the identical traceback, after which torchrun tears the job down:

```
W0823 16:06:21.585000 140016451217216 torch/distributed/elastic/multiprocessing/api.py:858] Sending process 2814441 closing signal SIGTERM
E0823 16:06:21.649000 140016451217216 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: 1) local_rank: 1 (pid: 2814442) of binary: /ML-A800/team/mm/chensenda/anaconda3/envs/XDIT/bin/python
Traceback (most recent call last):
  File "/ML-A800/team/mm/chensenda/anaconda3/envs/XDIT/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==2.4.0+cu118', 'console_scripts', 'torchrun')())
  File "/ML-A800/team/mm/chensenda/anaconda3/envs/XDIT/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 348, in wrapper
    return f(*args, **kwargs)
  File "/ML-A800/team/mm/chensenda/anaconda3/envs/XDIT/lib/python3.10/site-packages/torch/distributed/run.py", line 901, in main
    run(args)
  File "/ML-A800/team/mm/chensenda/anaconda3/envs/XDIT/lib/python3.10/site-packages/torch/distributed/run.py", line 892
```
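For context, the final assertion appears to come from optimum-quanto's tensor copy dispatch: when xFuser's attention wrapper rebuilds the attention layer, a quantized (frozen) weight seems to get copied into a freshly created, non-quantized tensor, and the `dest.qtype == src.qtype` check fails because the destination has no `qtype`. A minimal sketch of that failure mode outside xDiT (the `Proj` module below is purely hypothetical, not xDiT code):

```python
import torch
from optimum.quanto import freeze, qfloat8, quantize


# Hypothetical stand-in for a transformer sub-module.
class Proj(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)


quantized = Proj()
quantize(quantized, weights=qfloat8)  # replace Linear weights with qfloat8
freeze(quantized)                     # weights are now quanto quantized tensors

plain = torch.nn.Linear(8, 8)         # freshly created, non-quantized layer
# Copying a quantized weight into a plain tensor goes through quanto's copy
# dispatch, which asserts dest.qtype == src.qtype and should reproduce:
# AttributeError: 'Tensor' object has no attribute 'qtype'
plain.weight.data.copy_(quantized.linear.weight)
```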

feifeibear commented 3 weeks ago

Pull the latest main branch.

Add `--enable_sequential_cpu_offload` to the command-line args if you run on a single GPU.

```
image 0 saved to ./results/flux_result_dp1_cfg1_ulysses1_ringNone_tp1_pp1_patchNone_0.png
epoch time: 20.51 sec, memory: 2.56510976 GB
```

You can run it with 2.5 GB of VRAM.

Change `examples/run.sh`:

```bash
torchrun --nproc_per_node=$N_GPUS ./examples/$SCRIPT \
  --model $MODEL_ID \
  $PARALLEL_ARGS \
  $TASK_ARGS \
  $PIPEFUSION_ARGS \
  $OUTPUT_ARGS \
  --num_inference_steps $INFERENCE_STEP \
  --warmup_steps 0 \
  --prompt "A small dog" \
  $CFG_ARGS \
  $PARALLLEL_VAE \
  --enable_sequential_cpu_offload
```
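For a single GPU, the flag corresponds conceptually to diffusers' sequential CPU offload. Outside xDiT, the same memory-saving behaviour on a plain `FluxPipeline` looks roughly like this; it is a sketch of the diffusers API, not the xDiT code path:

```python
import torch
from diffusers import FluxPipeline

# Sketch: plain-diffusers equivalent of --enable_sequential_cpu_offload.
# Weights stay in CPU RAM and each submodule is streamed to the GPU only
# while it executes, trading speed for a much smaller VRAM footprint.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.enable_sequential_cpu_offload()

image = pipe("A small dog", num_inference_steps=28).images[0]
image.save("flux_result.png")
```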
csdY123 commented 3 weeks ago

> Pull the latest main branch.
>
> Add `--enable_sequential_cpu_offload` to the command-line args if you run on a single GPU.
>
> ```
> image 0 saved to ./results/flux_result_dp1_cfg1_ulysses1_ringNone_tp1_pp1_patchNone_0.png
> epoch time: 20.51 sec, memory: 2.56510976 GB
> ```
>
> You can run it with 2.5 GB of VRAM.
>
> Change `examples/run.sh`:
>
> ```bash
> torchrun --nproc_per_node=$N_GPUS ./examples/$SCRIPT \
>   --model $MODEL_ID \
>   $PARALLEL_ARGS \
>   $TASK_ARGS \
>   $PIPEFUSION_ARGS \
>   $OUTPUT_ARGS \
>   --num_inference_steps $INFERENCE_STEP \
>   --warmup_steps 0 \
>   --prompt "A small dog" \
>   $CFG_ARGS \
>   $PARALLLEL_VAE \
>   --enable_sequential_cpu_offload
> ```

This does not work. Running inference on the 4090 still reports an out-of-memory error.

feifeibear commented 3 weeks ago

I see. Even with `enable_sequential_cpu_offload`, the full model (about 33 GB) still has to be loaded onto the GPU first. That is why you hit OOM on the 4090.

I believe we need to rework `enable_sequential_cpu_offload()`.

Would you like to help us work on this feature?
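One possible direction (a sketch only, assuming `pipe` is an already-built, xFuser-wrapped FluxPipeline whose components still live on the CPU; the actual fix may look different): apply accelerate's per-module offload to each component instead of moving everything to CUDA up front, which is essentially what diffusers' `enable_sequential_cpu_offload()` does internally.

```python
import torch
from accelerate import cpu_offload

# Sketch: offload each pipeline component with accelerate instead of calling
# pipe.to("cuda"). Hooks move each submodule to `device` only for its forward
# pass, so the full model never has to fit in VRAM at once.
device = torch.device("cuda")
for component in (pipe.text_encoder, pipe.text_encoder_2, pipe.transformer, pipe.vae):
    if component is not None:
        cpu_offload(component, execution_device=device)
```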

feifeibear commented 3 weeks ago

@csdY123 I have fixed the problem in #221. Pull the latest main branch.