xdit-project / xDiT

xDiT: A Scalable Inference Engine for Diffusion Transformers (DiTs) with Massive Parallelism

FLUX Hopper benchmarking #324

Open antferdom opened 1 week ago

antferdom commented 1 week ago

Target: measure the scalability of FLUX.1 on the NVIDIA Hopper architecture (H100 and H200) using different model parallelism strategies (see the Flux.1 Performance Overview):

Models (in order of relevance):

  1. FLUX.1-dev
  2. FLUX.1-schnell

Benchmarking Python utility script

Build the xfuser command expression in Python instead of plain Bash scripts. This approach can then be combined with experiment configuration libraries such as OmegaConf (as used, e.g., in the facebookresearch/lingua LLM pre-training library); a minimal OmegaConf sketch follows the script below. The implementation also configures logging for torch.compile (Inductor, Dynamo), which is valuable for inspecting graph breaks, tracing, and compilation errors:

import logging
import os
import shlex
import subprocess
import sys
from pathlib import Path

import torch

# Opt-in verbose torch.compile logging via the DEBUG_TORCH_COMPILE env var.
DEBUG_TORCH_COMPILE = os.getenv("DEBUG_TORCH_COMPILE", "0") == "1"

# ref: https://github.com/pytorch/pytorch/blob/c81d4fd0a8ca826c165fe15f83398dbc4e20b523/docs/source/torch.compiler_troubleshooting.rst#L28
if DEBUG_TORCH_COMPILE:
    torch._logging.set_logs(dynamo=logging.INFO,
                            graph_code=True,
                            graph_breaks=True,
                            guards=True,
                            recompiles=True,
                            )
    #torch._dynamo.explain()

# Make the example script resolvable relative to this file.
WD: Path = Path(__file__).parent.absolute()
os.environ["PYTHONPATH"] = f"{WD}:{os.getenv('PYTHONPATH', '')}"
SCRIPT: str = os.path.join(WD, "flux_example.py")

MODEL_ID = "black-forest-labs/FLUX.1-dev"
INFERENCE_STEP = 28
WARMUP_STEPS = 3
max_sequence_length = 512
height = 1024
width = 1024
TASK_ARGS = f"--max-sequence-length {max_sequence_length} --height {height} --width {width}"
N_GPUS = 2
pipefusion_parallel_degree = 2
ulysses_degree = 1
ring_degree = 1
PARALLEL_ARGS = (
    f"--pipefusion_parallel_degree {pipefusion_parallel_degree} "
    f"--ulysses_degree {ulysses_degree} "
    f"--ring_degree {ring_degree} "
)
COMPILE_FLAG = "--use_torch_compile"

# torchrun binary next to the current interpreter; kept as an alternative
# launcher to `python -m torch.distributed.run` used below.
conda_binaries = Path(sys.executable).parent
torchrun_bin_path = os.path.join(conda_binaries, "torchrun")
command: str = (
                f"{sys.executable} -m torch.distributed.run --nproc_per_node={N_GPUS} {SCRIPT} "
                f"--model {MODEL_ID} "
                f"{PARALLEL_ARGS} "
                f"{TASK_ARGS} "
                f"--num_inference_steps {INFERENCE_STEP} "
                f"--warmup_steps {WARMUP_STEPS} "
                f"--prompt \"brown dog laying on the ground with a metal bowl in front of him.\" "
                f"{COMPILE_FLAG}"
            )
print(command)
print(shlex.split(command))
subprocess.run(shlex.split(command), check=True)
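As a minimal sketch of the OmegaConf integration mentioned above (the defaults dictionary and the CLI-override pattern are illustrative, not part of xDiT), the hard-coded constants could be lifted into a config and overridden from the command line:

from omegaconf import OmegaConf

# Defaults mirroring the constants in the script above; any key can be
# overridden from the CLI, e.g. `python bench.py n_gpus=4 ulysses_degree=2`.
defaults = OmegaConf.create(
    {
        "model_id": "black-forest-labs/FLUX.1-dev",
        "n_gpus": 2,
        "pipefusion_parallel_degree": 2,
        "ulysses_degree": 1,
        "ring_degree": 1,
        "num_inference_steps": 28,
        "warmup_steps": 3,
        "height": 1024,
        "width": 1024,
        "max_sequence_length": 512,
    }
)
cfg = OmegaConf.merge(defaults, OmegaConf.from_cli())

parallel_args = (
    f"--pipefusion_parallel_degree {cfg.pipefusion_parallel_degree} "
    f"--ulysses_degree {cfg.ulysses_degree} "
    f"--ring_degree {cfg.ring_degree} "
)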

Minor change in examples/flux_example.py: set the xFuserFluxPipeline max_sequence_length from input_config instead of a hard-coded value. This modification resolves Torch compilation errors when setting max_sequence_length=512 for FLUX.1-dev.

    output = pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        output_type=input_config.output_type,
        max_sequence_length=input_config.max_sequence_length,
        guidance_scale=0.0,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    )
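For the actual measurements, a minimal timing sketch around the call above (the timed_generation helper and the use of time.perf_counter are my own, not part of the example script; it assumes warmup already ran via --warmup_steps):

import time

import torch

def timed_generation(pipe, input_config):
    """Time one end-to-end generation on the current rank."""
    torch.cuda.synchronize()
    start = time.perf_counter()
    output = pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        output_type=input_config.output_type,
        max_sequence_length=input_config.max_sequence_length,
        guidance_scale=0.0,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    )
    torch.cuda.synchronize()
    elapsed_s = time.perf_counter() - start
    print(f"end-to-end latency: {elapsed_s:.3f} s")
    return output, elapsed_s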

@feifeibear: any suggested configurations for the parallel layouts?
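To make the question concrete, here is a sketch of the layout sweep I have in mind (assumption: the product of the three parallel degrees must equal the number of GPUs):

from itertools import product

N_GPUS = 8  # e.g. one full H100/H200 node

# Assumption: pipefusion_parallel_degree * ulysses_degree * ring_degree == N_GPUS.
layouts = [
    (p, u, r)
    for p, u, r in product((1, 2, 4, 8), repeat=3)
    if p * u * r == N_GPUS
]
for p, u, r in layouts:
    print(f"--pipefusion_parallel_degree {p} --ulysses_degree {u} --ring_degree {r}")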

feifeibear commented 6 days ago

Perfect! Could you please share some results from your H100 benchmark?

antferdom commented 6 days ago

I am currently having issues with the Flash Attention v3 Hopper interface in combination with SP-Ulysses and SP-Ring. I will forward this to #319.