unslothai / unsloth

Finetune Llama 3.2, Mistral, Phi, Qwen 2.5 & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

Request: Flux (Diffusion transformer) #876

Open RefractAI opened 3 months ago

RefractAI commented 3 months ago

I anticipate there will be a lot of demand to train (and run inference on) the new open SOTA image model "Flux". It's currently the top model on HF. It's a 12B diffusion transformer, which means it's too big to train on a consumer GPU without quantization, and it is very slow as-is. The image model community hasn't done QLoRA training before, since models have not been this big.

I appreciate that image models are a little different, but the rest of the training loop inputs can essentially be cached or adapted easily. The important part is to reduce the memory use and increase the performance of the 12B transformer model in the following dummy HuggingFace diffusers code:

    import torch
    from diffusers import FluxTransformer2DModel

    # Load only the 12B transformer from the Flux checkpoint.
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
    )

    # Dummy forward pass with random inputs of the expected shapes.
    model_pred = transformer(
        hidden_states=torch.randn(1, 4320, 64),
        timestep=torch.randn(1),
        guidance=torch.randn(1),
        pooled_projections=torch.randn(1, 768),
        encoder_hidden_states=torch.randn(1, 512, 4096),
        txt_ids=torch.randn(1, 512, 3),
        img_ids=torch.randn(1, 4320, 3),
        joint_attention_kwargs=None,
        return_dict=False,
    )

Model code: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py
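For context, here is a rough sketch of what a QLoRA-style setup for this transformer might look like with stock diffusers + bitsandbytes + peft today. The quantization_config argument, the LoRA target module names, and the add_adapter call are assumptions that depend on the installed diffusers/peft versions, not something unsloth provides yet:

    import torch
    from diffusers import BitsAndBytesConfig, FluxTransformer2DModel
    from peft import LoraConfig

    # Assumption: this diffusers build exposes BitsAndBytesConfig and accepts
    # quantization_config in from_pretrained. Load the 12B transformer in 4-bit NF4
    # so the frozen base weights fit in consumer-GPU memory.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
    )

    # Assumption: the attention projections are named to_q/to_k/to_v/to_out.0 as in
    # other diffusers transformer blocks. Attach small trainable LoRA adapters there
    # and keep the quantized base weights frozen.
    lora_config = LoraConfig(
        r=16,
        lora_alpha=16,
        init_lora_weights="gaussian",
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    )
    transformer.add_adapter(lora_config)

    # Only the LoRA parameters should require gradients now.
    trainable = [n for n, p in transformer.named_parameters() if p.requires_grad]
    print(f"{len(trainable)} trainable LoRA parameter tensors")

Even with a setup like this, the attention and MLP kernels remain stock PyTorch, which is where unsloth-style optimizations would presumably make the biggest difference.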

Would it be possible to consider looking at this?

danielhanchen commented 3 months ago

Diffusion / Llava type models are next on our roadmap!

al-swaiti commented 1 month ago

You've done amazing work for Llama. I wish I could do LoRA fine-tuning of Flux using unslothai.

danielhanchen commented 1 month ago

@al-swaiti Yep working on them!