replicate / cog-flux

Cog inference for flux models
https://replicate.com/black-forest-labs/flux-dev
Apache License 2.0

Skip quantization for GPUs with <8.9 Compute Capability #27

Closed · Averylamp closed 3 weeks ago

Averylamp commented 3 weeks ago

The current quantization pipeline only works on GPUs with >=8.9 Compute Capability (H100, L40, L4, RTX 6000 Ada), so it currently fails on every GPU available on Replicate besides H100s (which I currently don't have access to :pray:)

This PR skips creating the FP8 model pipeline when a Compute Capability below 8.9 is detected.
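
For illustration, here's a minimal sketch of that guard using PyTorch's `torch.cuda.get_device_capability` (the threshold constant and function name are mine, not the repo's actual code):

```python
import torch

# Native FP8 matmul kernels need Ada/Hopper-class hardware, i.e. CUDA
# Compute Capability 8.9+ (L4, L40, RTX 6000 Ada, H100).
FP8_MIN_CC = (8, 9)

def supports_fp8(device_index: int = 0) -> bool:
    """Return True if the given CUDA device can run native FP8 kernels."""
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability(device_index)
    return (major, minor) >= FP8_MIN_CC

if __name__ == "__main__":
    print(f"FP8 supported: {supports_fp8()}")
```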

The error raised when compiling the FP8-quantized pipeline is the one reported in #23.

There are other quantization methods that would work on <8.9 GPUs, such as optimum-quanto, but I opted to just throw an error if go_fast is requested on an FP8-incompatible GPU.
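
For reference, a minimal sketch of both paths; `check_go_fast`, `quantize_with_quanto`, and the `pipe` argument are illustrative names (not the repo's actual code), while `quantize`/`freeze`/`qfloat8` are optimum-quanto's documented entry points:

```python
import torch
from optimum.quanto import freeze, qfloat8, quantize

def check_go_fast(go_fast: bool) -> None:
    """What this PR opts for: fail loudly instead of silently falling back."""
    major, minor = torch.cuda.get_device_capability()
    if go_fast and (major, minor) < (8, 9):
        raise ValueError(
            "go_fast requires an FP8-capable GPU (Compute Capability >= 8.9)"
        )

def quantize_with_quanto(pipe) -> None:
    """The optimum-quanto alternative, which also runs on <8.9 CC GPUs."""
    # qfloat8 weight quantization doesn't rely on native FP8 matmul kernels,
    # so it works on pre-Ada hardware; `pipe` is assumed to be an
    # already-loaded diffusers Flux pipeline.
    quantize(pipe.transformer, weights=qfloat8)
    freeze(pipe.transformer)
```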

To support other quantization methods, it'd likely be faster to pull a pre-quantized checkpoint rather than quantize on the fly.
Also, on older GPUs with <8.9 CC there's limited benefit to FP8 quantization beyond lower VRAM usage; the improvement in inference time is minimal (~12s vs. ~15s generation time for Flux-Dev on an A100 in some quick testing).

Please let me know if you'd like anything changed and I'd be happy to update the PR.

daanelson commented 3 weeks ago

@Averylamp thanks for this! we definitely should disable fp8 on devices that can't run it. just merged a separate PR (#29) that addresses this, check that out.

Averylamp commented 3 weeks ago

Awesome, thanks. You can probably close #23 too. Out of curiosity, is there also a way to get H100 access as a Replicate customer? Or at least access to FP8-capable GPUs?