huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Error in loading flux fp8 model with local transformer_flux.py file #9667

Open DhavalTaunk08 opened 1 month ago

DhavalTaunk08 commented 1 month ago

Describe the bug

I am unable to load the Flux FP8 model from Kijai/flux-fp8 when using a local copy of transformer_flux.py. I modified the scripts to resolve all import errors and added print statements to single_file_model.py to find out why the model is not loading.

Reproduction

The following code works as expected.

diffusers/loaders/single_file_model.py (with debug prints added):

# Excerpt from the library file; importlib and SINGLE_FILE_LOADABLE_CLASSES are already available at module level there.
def _get_single_file_loadable_mapping_class(cls):
    print(cls)
    diffusers_module = importlib.import_module(__name__.split(".")[0])

    for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
        loadable_class = getattr(diffusers_module, loadable_class_str)
        print(cls, loadable_class)
        print(issubclass(cls, loadable_class))
        if issubclass(cls, loadable_class):
            return loadable_class_str

    return None

Test script:

import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-schnell-fp8-e4m3fn.safetensors",
    torch_dtype=torch.bfloat16,
)
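The lookup above can be reduced to a minimal, self-contained sketch. It uses plain stand-in classes and a fake module in place of diffusers, and a hypothetical `LOADABLE_CLASS_NAMES` list in place of the keys of `SINGLE_FILE_LOADABLE_CLASSES`. It only illustrates that the match is decided by `issubclass`, so a class qualifies only if it is, or inherits from, one of the registered class objects.

```python
import types

# Hypothetical stand-ins for model classes exported by the top-level package.
class StableCascadeUNet:
    pass

class UNet2DConditionModel:
    pass

class FluxTransformer2DModel:
    pass

# Stand-in for the top-level module the real code gets via importlib.import_module.
fake_diffusers = types.ModuleType("fake_diffusers")
for _cls in (StableCascadeUNet, UNet2DConditionModel, FluxTransformer2DModel):
    setattr(fake_diffusers, _cls.__name__, _cls)

# Hypothetical stand-in for the keys of SINGLE_FILE_LOADABLE_CLASSES, in lookup order.
LOADABLE_CLASS_NAMES = ["StableCascadeUNet", "UNet2DConditionModel", "FluxTransformer2DModel"]


def get_loadable_mapping_class(cls, module=fake_diffusers):
    """Return the name of the first registered class that `cls` is or inherits from, else None."""
    for name in LOADABLE_CLASS_NAMES:
        loadable_class = getattr(module, name)
        # issubclass walks the inheritance chain and compares class objects, not names.
        if issubclass(cls, loadable_class):
            return name
    return None


class CustomFlux(FluxTransformer2DModel):
    """Hypothetical user subclass: it resolves to its parent's entry."""


print(get_loadable_mapping_class(FluxTransformer2DModel))  # FluxTransformer2DModel
print(get_loadable_mapping_class(CustomFlux))              # FluxTransformer2DModel
print(get_loadable_mapping_class(int))                     # None
```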

It prints the following output:

<class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'>
<class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'> <class 'diffusers.models.unets.unet_stable_cascade.StableCascadeUNet'>
False
<class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'> <class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>
False
<class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'> <class 'diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL'>
False
<class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'> <class 'diffusers.models.controlnet.ControlNetModel'>
False
<class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'> <class 'diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel'>
False
<class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'> <class 'diffusers.models.unets.unet_motion_model.MotionAdapter'>
False
<class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'> <class 'diffusers.models.controlnet_sparsectrl.SparseControlNetModel'>
False
<class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'> <class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'>
True

However, when I use the class from my local copy of the code:

import torch
from transformer_flux import FluxTransformer2DModel  # local copy of transformer_flux.py

# Make the local class report the same module path as the diffusers class
FluxTransformer2DModel.__module__ = 'diffusers.models.transformers.transformer_flux'
transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-schnell-fp8-e4m3fn.safetensors",
    torch_dtype=torch.bfloat16,
)

I get the following error:

<class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'>
<class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'> <class 'diffusers.models.unets.unet_stable_cascade.StableCascadeUNet'>
False
<class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'> <class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>
False
<class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'> <class 'diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL'>
False
<class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'> <class 'diffusers.models.controlnet.ControlNetModel'>
False
<class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'> <class 'diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel'>
False
<class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'> <class 'diffusers.models.unets.unet_motion_model.MotionAdapter'>
False
<class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'> <class 'diffusers.models.controlnet_sparsectrl.SparseControlNetModel'>
False
<class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'> <class 'diffusers.models.transformers.transformer_flux.FluxTransformer2DModel'>
False

Traceback (most recent call last):
  File "/workspace/GarmentTransferV2/test.py", line 441, in <module>
    main(args)
  File "/workspace/GarmentTransferV2/test.py", line 368, in main
    transformer_garment = FluxTransformerGarment2DModel.from_single_file(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/garment/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/workspace/garment/lib/python3.11/site-packages/diffusers/loaders/single_file_model.py", line 182, in from_single_file
    raise ValueError(
ValueError: FromOriginalModelMixin is currently only compatible with StableCascadeUNet, UNet2DConditionModel, AutoencoderKL, ControlNetModel, SD3Transformer2DModel, MotionAdapter, SparseControlNetModel, FluxTransformer2DModel
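The final False in the output above matches how plain Python behaves: `issubclass` compares class objects through the inheritance chain, so a class defined in a separate local file is a different object from diffusers' `FluxTransformer2DModel`, even after its `__module__` is overwritten so that both print identically. The sketch below reproduces this with hypothetical stand-in classes and does not import diffusers. It also hints at one possible workaround, not verified against diffusers: have the local class inherit from the library's `FluxTransformer2DModel` so the check passes, assuming its weights still fit the Flux checkpoint conversion.

```python
def make_flux_class():
    """Create a fresh class object, as happens when the same source is loaded from two files."""

    class FluxTransformer2DModel:
        pass

    # Make the class print exactly like diffusers' class (mimics the __module__ override).
    FluxTransformer2DModel.__module__ = "diffusers.models.transformers.transformer_flux"
    FluxTransformer2DModel.__qualname__ = "FluxTransformer2DModel"
    return FluxTransformer2DModel


LibraryFlux = make_flux_class()  # stand-in for diffusers' FluxTransformer2DModel
LocalFlux = make_flux_class()    # stand-in for the class from the local transformer_flux.py

print(LibraryFlux)                           # same repr as in the log above
print(repr(LibraryFlux) == repr(LocalFlux))  # True: identical names
print(issubclass(LocalFlux, LibraryFlux))    # False: distinct class objects


class PatchedLocalFlux(LibraryFlux):
    """Hypothetical workaround: inherit from the library class instead of copying it."""


print(issubclass(PatchedLocalFlux, LibraryFlux))  # True
```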

Any leads would be appreciated.

Logs

No response

System Info

Who can help?

@DN6 @sayakpaul

a-r-r-o-w commented 1 month ago

You cannot pass URLs to from_single_file. Could you try the following instead?

from diffusers import FluxTransformer2DModel
from huggingface_hub import hf_hub_download

safetensors_file = hf_hub_download("Kijai/flux-fp8", filename="flux1-dev-fp8-e4m3fn.safetensors")
print(safetensors_file)

transformer = FluxTransformer2DModel.from_single_file(safetensors_file, subfolder="transformer")
print(transformer.config)

Edit: My bad, you can pass URLs to from_single_file; I confused it with something else. I think you just need to pass subfolder="transformer" when initializing to make your snippet work. This is because we try to fetch the init config from the original Flux-Dev repository (as Kijai/flux-fp8 is based on Flux-Dev).

github-actions[bot] commented 1 week ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.