aredden / flux-fp8-api

Flux diffusion model implementation using quantized fp8 matmuls; the remaining layers use faster half-precision accumulation, which is ~2x faster on consumer devices.
Apache License 2.0

Issue: torch._scaled_mm RuntimeError on RTX 6000 (with runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04) #30

Closed — veyorokon closed this 3 weeks ago

veyorokon commented 3 weeks ago

Description
When using the flux-fp8-api with configuration .configs/config-dev-1-RTX6000ADA.json on an RTX 6000, I receive a RuntimeError regarding unsupported torch._scaled_mm due to compute capability requirements. My environment uses the Docker image runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04.

Docker Image:
runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04

Error Details

RuntimeError: torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+

Relevant Configuration Path

config_path = ".configs/config-dev-1-RTX6000ADA.json"

Has anyone encountered this before?

aredden commented 3 weeks ago

Oh — it could be that you're using an RTX 6000 (non-Ada), which is different from the RTX 6000 Ada. They have similar names, but one is the Ada generation and the other is from the previous generation, which would have compute capability 8.6.
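To check which card you actually have before picking a config, you can query the device's compute capability at runtime with `torch.cuda.get_device_capability()` and compare it against the requirement stated in the error message (>= 9.0, or exactly 8.9 for Ada-generation cards). Below is a minimal sketch; the helper function name is my own, not part of the repo:

```python
def supports_fp8_scaled_mm(capability):
    """Return True if a device with this (major, minor) compute
    capability satisfies torch._scaled_mm's fp8 requirement, per the
    error message: capability >= 9.0, or exactly 8.9 (Ada)."""
    major, minor = capability
    return (major, minor) == (8, 9) or major >= 9

# Examples of capabilities mentioned in this thread:
#   RTX 6000 Ada      -> (8, 9): supported
#   previous-gen card -> (8, 6): not supported

# At runtime on a CUDA machine you would check the actual device, e.g.:
#   import torch
#   cap = torch.cuda.get_device_capability()
#   print(torch.cuda.get_device_name(), cap, supports_fp8_scaled_mm(cap))
```

If this prints a capability below 8.9, the fp8 configs (such as `config-dev-1-RTX6000ADA.json`) will not work on that device.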

veyorokon commented 3 weeks ago

gotcha — was wondering about that possibility, thanks