replicate / cog-flux

Cog inference for Flux models
https://replicate.com/black-forest-labs/flux-dev
Apache License 2.0

torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+ #23

Open · micos7 opened this issue 1 month ago

I uploaded the repo to a model I own, and when I run a prediction I get this:

Booting model flux-dev
Detected GPU: NVIDIA A40
downloading url:  https://weights.replicate.delivery/default/sdxl/safety-1.0.tar
downloading to:  /src/safety-cache
2024-09-25T10:35:15Z | INFO  | [ Initiating ] chunk_size=150M dest=/src/safety-cache url=https://weights.replicate.delivery/default/sdxl/safety-1.0.tar
2024-09-25T10:35:16Z | INFO  | [ Complete ] dest=/src/safety-cache size="608 MB" total_elapsed=0.473s url=https://weights.replicate.delivery/default/sdxl/safety-1.0.tar
downloading took:  0.5735323429107666
Loading Safety Checker to GPU
Loading Falcon safety checker...
downloading url:  https://weights.replicate.delivery/default/falconai/nsfw-image-detection.tar
downloading to:  /src/falcon-cache
2024-09-25T10:35:16Z | INFO  | [ Initiating ] chunk_size=150M dest=/src/falcon-cache url=https://weights.replicate.delivery/default/falconai/nsfw-image-detection.tar
2024-09-25T10:35:17Z | INFO  | [ Complete ] dest=/src/falcon-cache size="343 MB" total_elapsed=0.342s url=https://weights.replicate.delivery/default/falconai/nsfw-image-detection.tar
downloading took:  0.4393315315246582
downloading url:  https://weights.replicate.delivery/default/official-models/flux/t5/t5-v1_1-xxl.tar
downloading to:  ./model-cache/t5
2024-09-25T10:35:17Z | INFO  | [ Initiating ] chunk_size=150M dest=./model-cache/t5 url=https://weights.replicate.delivery/default/official-models/flux/t5/t5-v1_1-xxl.tar
2024-09-25T10:35:39Z | INFO  | [ Complete ] dest=./model-cache/t5 size="9.5 GB" total_elapsed=21.727s url=https://weights.replicate.delivery/default/official-models/flux/t5/t5-v1_1-xxl.tar
downloading took:  22.41786813735962
Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  6.74it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  8.32it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  8.03it/s]
downloading url:  https://weights.replicate.delivery/default/official-models/flux/clip/clip-vit-large-patch14.tar
downloading to:  ./model-cache/clip
2024-09-25T10:35:45Z | INFO  | [ Initiating ] chunk_size=150M dest=./model-cache/clip url=https://weights.replicate.delivery/default/official-models/flux/clip/clip-vit-large-patch14.tar
2024-09-25T10:35:45Z | INFO  | [ Complete ] dest=./model-cache/clip size="248 MB" total_elapsed=0.284s url=https://weights.replicate.delivery/default/official-models/flux/clip/clip-vit-large-patch14.tar
downloading took:  0.3734302520751953
Init model
downloading url:  https://weights.replicate.delivery/default/official-models/flux/dev/dev.sft
downloading to:  ./model-cache/dev/dev.sft
2024-09-25T10:35:45Z | INFO  | [ Initiating ] chunk_size=150M dest=./model-cache/dev/dev.sft url=https://weights.replicate.delivery/default/official-models/flux/dev/dev.sft
2024-09-25T10:36:08Z | INFO  | [ Complete ] dest=./model-cache/dev/dev.sft size="24 GB" total_elapsed=22.286s url=https://weights.replicate.delivery/default/official-models/flux/dev/dev.sft
downloading took:  22.616339206695557
Loading checkpoint
Init AE
downloading url:  https://weights.replicate.delivery/default/official-models/flux/ae/ae.sft
downloading to:  ./model-cache/ae/ae.sft
2024-09-25T10:36:08Z | INFO  | [ Initiating ] chunk_size=150M dest=./model-cache/ae/ae.sft url=https://weights.replicate.delivery/default/official-models/flux/ae/ae.sft
2024-09-25T10:36:09Z | INFO  | [ Complete ] dest=./model-cache/ae/ae.sft size="335 MB" total_elapsed=0.339s url=https://weights.replicate.delivery/default/official-models/flux/ae/ae.sft
downloading took:  0.461498498916626
Running warmups for compile...
Using seed: 10
  0%|          | 0/12 [00:00<?, ?it/s]
0%|          | 0/12 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/root/.pyenv/versions/3.11.10/lib/python3.11/site-packages/cog/server/worker.py", line 312, in _setup
    run_setup(self._predictor)
  File "/root/.pyenv/versions/3.11.10/lib/python3.11/site-packages/cog/predictor.py", line 89, in run_setup
    predictor.setup()
  File "/src/predict.py", line 487, in setup
    self.base_setup("flux-dev", compile_fp8=True)
  File "/src/predict.py", line 164, in base_setup
    self.fp8_pipe = FluxPipeline.load_pipeline_from_config_path(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/src/fp8/flux_pipeline.py", line 730, in load_pipeline_from_config_path
    return cls.load_pipeline_from_config(config, debug=debug, shared_models=shared_models)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/src/fp8/flux_pipeline.py", line 766, in load_pipeline_from_config
    return cls(
           ^^^^
  File "/src/fp8/flux_pipeline.py", line 120, in __init__
    self.compile()
  File "/root/.pyenv/versions/3.11.10/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/src/fp8/flux_pipeline.py", line 194, in compile
    self.generate(**warmup_dict)
  File "/root/.pyenv/versions/3.11.10/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/src/fp8/flux_pipeline.py", line 639, in generate
    denoised_img = self.denoise_single_item(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/src/fp8/flux_pipeline.py", line 703, in denoise_single_item
    pred = self.model(
           ^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.10/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.10/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/src/fp8/modules/flux_model.py", line 655, in forward
    img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.10/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.10/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/src/fp8/modules/flux_model.py", line 365, in forward
    img_mod1, img_mod2 = self.img_mod(vec)
                         ^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.10/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.10/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/src/fp8/modules/flux_model.py", line 255, in forward
    out = self.lin(self.act(vec))[:, None, :].chunk(self.multiplier, dim=-1)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.10/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.10/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/src/fp8/float8_quantize.py", line 279, in forward
    out = torch._scaled_mm(  # noqa
          ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+
Traceback (most recent call last):
  File "/root/.pyenv/versions/3.11.10/lib/python3.11/site-packages/cog/server/runner.py", line 222, in _handle_done
    f.result()
  File "/root/.pyenv/versions/3.11.10/lib/python3.11/concurrent/futures/_base.py", line 449, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.10/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
cog.server.exceptions.FatalWorkerException: Predictor errored during setup: torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+

Do I need to upgrade CUDA, or switch to a "higher" GPU? Thanks for your work!
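Compute capability is a property of the GPU hardware, so upgrading the CUDA toolkit alone does not change it. The NVIDIA A40 detected in the log is an Ampere card with compute capability 8.6, which is below the 8.9 minimum named in the error. The sketch below shows one way to check this before taking an FP8 code path. It is a minimal illustration under stated assumptions: the helper names (`supports_fp8_scaled_mm`, `detect_capability`, `choose_compile_fp8`, `MIN_FP8_CAPABILITY`) are hypothetical and not part of cog-flux; only `torch.cuda.is_available` and `torch.cuda.get_device_capability` are real PyTorch calls.

```python
from typing import Optional, Tuple

# Threshold taken from the RuntimeError text: compute capability 8.9 or >= 9.0.
# (The message also accepts ROCm MI300+, which this sketch does not model.)
MIN_FP8_CAPABILITY: Tuple[int, int] = (8, 9)


def supports_fp8_scaled_mm(capability: Tuple[int, int]) -> bool:
    """True if a (major, minor) CUDA compute capability meets the minimum
    quoted in the torch._scaled_mm error message."""
    # Tuple comparison is lexicographic: (8, 6) < (8, 9) <= (9, 0).
    return tuple(capability) >= MIN_FP8_CAPABILITY


def detect_capability() -> Optional[Tuple[int, int]]:
    """Return the current CUDA device's (major, minor) capability,
    or None if PyTorch or a CUDA device is unavailable."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability()


def choose_compile_fp8(capability: Optional[Tuple[int, int]]) -> bool:
    """Hypothetical helper: enable the FP8 path only on supported GPUs."""
    return capability is not None and supports_fp8_scaled_mm(capability)


if __name__ == "__main__":
    cap = detect_capability()
    print(f"compute capability: {cap}; FP8 _scaled_mm path usable: {choose_compile_fp8(cap)}")
```

On an A40 this check would report (8, 6) and keep the FP8 path disabled. Wiring it into setup, for example as `self.base_setup("flux-dev", compile_fp8=choose_compile_fp8(detect_capability()))`, assumes the predictor has a working non-FP8 fallback when `compile_fp8=False`; the issue does not show whether it does.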