Parallel inference for black-forest-labs' FLUX model.
`piflux` is a parallel inference optimization library for black-forest-labs' FLUX model. It works with `torch.distributed` to use more than one NVIDIA GPU in parallel, further reducing the time needed to generate one image.

This library is only for demonstration purposes and is not intended for production use. Whether it can be combined with other optimization techniques such as `torch.compile` has not been tested.
```bash
git clone https://github.com/chengzeyi/piflux.git
cd piflux
git submodule update --init --recursive
pip3 install 'setuptools>=64,<70' 'setuptools_scm>=8'
pip3 install -e '.[dev]'
```
```bash
# Code formatting and linting
pip3 install pre-commit
pre-commit install
pre-commit run --all-files
```
```bash
N_GPUS=2
torchrun --nproc_per_node=$N_GPUS examples/run_flux.py --print-output
```
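`torchrun` starts one process per GPU, and each process handles a part of the work. As a rough, hypothetical illustration (this is not piflux's actual code, and `split_for_ranks`/`gather_outputs` are made-up names), the underlying idea of splitting a latent token sequence into per-rank parts and reassembling the results can be sketched with plain Python lists standing in for tensors:

```python
# Hypothetical sketch, NOT piflux's real API: split a token sequence into
# contiguous per-rank chunks and reassemble the per-rank outputs.

def split_for_ranks(tokens, world_size):
    """Split a token sequence into world_size contiguous chunks."""
    base, rem = divmod(len(tokens), world_size)
    chunks, start = [], 0
    for rank in range(world_size):
        size = base + (1 if rank < rem else 0)  # spread the remainder
        chunks.append(tokens[start:start + size])
        start += size
    return chunks

def gather_outputs(chunks):
    """Concatenate per-rank outputs back into one sequence."""
    out = []
    for c in chunks:
        out.extend(c)
    return out

tokens = list(range(10))            # stand-in for latent tokens
parts = split_for_ranks(tokens, 3)  # one part per "GPU"
assert gather_outputs(parts) == tokens
```

Splitting into contiguous chunks keeps each rank's share of the attention sequence local; the round-trip assertion checks that no tokens are lost or reordered when the outputs are gathered back.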
`piflux` replicates the model on each GPU that it uses, and its work is split into two phases. During the first denoising step, when `torch.nn.functional.scaled_dot_product_attention` is called, we call `all_gather` to collect the full KV and store it in a cache. During subsequent steps, we no longer need to call `all_gather` to get the full KV. Instead, we cycle through the input latent parts and assign them to different GPUs, so that when `torch.nn.functional.scaled_dot_product_attention` is called, we can use the cached KV from the previous steps and update different parts of the cached KV periodically. The output noise prediction is then collected and concatenated to form the final output. By doing so, `piflux` minimizes the need for inter-GPU communication and can achieve better performance than other implementations.
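The two-phase scheme above can be sketched as a toy simulation. This is a hedged illustration, not piflux's implementation: plain Python lists stand in for KV tensors, `all_gather` and `run_steps` are made-up names, and each cache entry just records the step that last refreshed it.

```python
# Toy simulation of the two-phase KV scheme (no real torch.distributed calls).

def all_gather(local_kv_per_rank):
    # Phase 1: every rank receives the full KV (one entry per rank).
    return list(local_kv_per_rank)

def run_steps(world_size, num_steps):
    # kv_cache[i] records which step last refreshed part i of the KV.
    kv_cache = all_gather([0] * world_size)  # first step: full all_gather
    for step in range(1, num_steps):
        for rank in range(world_size):
            # Phase 2: each rank is assigned a different latent part each
            # step (round-robin) and refreshes only that part of the cache.
            part = (rank + step) % world_size
            kv_cache[part] = step
    return kv_cache

# After a few steps every part of the cache has been refreshed recently,
# even though each rank only updated one part per step.
print(run_steps(world_size=2, num_steps=4))  # -> [3, 3]
```

The point of the round-robin assignment is that the cache never goes stale by more than `world_size` steps, while the per-step communication stays proportional to one part rather than the full KV.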
| Configuration | Wall Time | Iterations per Second |
|---|---|---|
| 1 x H100 (baseline) | 6.37 s | 4.52 |
| 2 x H100 | 3.74 s | 8.16 |
| 2 x H100 plus a Mysterious Compiler | 1.74 s | 18.64 |
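For reference, the wall times above imply the following speedups over the single-GPU baseline (this is simple arithmetic on the table, not numbers reported by the source):

```python
# Speedup factors implied by the wall times in the table above.
baseline = 6.37  # 1 x H100 wall time in seconds
for label, t in [("2 x H100", 3.74),
                 ("2 x H100 + compiler", 1.74)]:
    print(f"{label}: {baseline / t:.2f}x faster")
```

That is roughly a 1.70x speedup from the second GPU alone, and about 3.66x when combined with the compiler.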