ostris / ai-toolkit

Various AI scripts. Mostly Stable Diffusion stuff.
MIT License

Any point to use xformers with Flux? #75

Open henryfw opened 3 months ago

henryfw commented 3 months ago

First of all, thanks for this great repo! I'm running out of VRAM on 24 GB trying to train a 128-rank Flux LoRA. When I install xformers and try to use it, I get an error:

Traceback (most recent call last):
  File "D:\ai-toolkit\run.py", line 90, in <module>
    main()
  File "D:\ai-toolkit\run.py", line 86, in main
    raise e
  File "D:\ai-toolkit\run.py", line 78, in main
    job.run()
  File "D:\ai-toolkit\jobs\ExtensionJob.py", line 22, in run
    process.run()
  File "D:\ai-toolkit\jobs\process\BaseSDTrainProcess.py", line 1701, in run
    loss_dict = self.hook_train_loop(batch)
  File "D:\ai-toolkit\extensions_built_in\sd_trainer\SDTrainer.py", line 1483, in hook_train_loop
    noise_pred = self.predict_noise(
  File "D:\ai-toolkit\extensions_built_in\sd_trainer\SDTrainer.py", line 891, in predict_noise
    return self.sd.predict_noise(
  File "D:\ai-toolkit\toolkit\stable_diffusion_model.py", line 1650, in predict_noise
    noise_pred = self.unet(
  File "D:\ai-toolkit\venv\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "D:\ai-toolkit\venv\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\ai-toolkit\venv\lib\site-packages\diffusers\models\transformers\transformer_flux.py", line 400, in forward
    encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
  File "D:\ai-toolkit\venv\lib\site-packages\torch\_compile.py", line 31, in inner
    return disable_fn(*args, **kwargs)
  File "D:\ai-toolkit\venv\lib\site-packages\torch\_dynamo\eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
  File "D:\ai-toolkit\venv\lib\site-packages\torch\utils\checkpoint.py", line 488, in checkpoint
    ret = function(*args, **kwargs)
  File "D:\ai-toolkit\venv\lib\site-packages\diffusers\models\transformers\transformer_flux.py", line 395, in custom_forward
    return module(*inputs)
  File "D:\ai-toolkit\venv\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "D:\ai-toolkit\venv\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\ai-toolkit\venv\lib\site-packages\diffusers\models\transformers\transformer_flux.py", line 201, in forward
    attn_output, context_attn_output = self.attn(
ValueError: not enough values to unpack (expected 2, got 1)
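For what it's worth, the traceback suggests a return-shape mismatch rather than a memory problem: the Flux transformer block unpacks the attention call into two tensors (the image stream and the text/context stream), while an xformers-style attention processor returns a single tensor, so the unpack fails. A minimal sketch of that mismatch, using simplified stand-in classes rather than the actual diffusers processors:

import torch

class JointAttnStub:
    # Stand-in for Flux joint attention: returns BOTH streams as a 2-tuple.
    def __call__(self, hidden_states, encoder_hidden_states):
        return hidden_states, encoder_hidden_states

class XFormersAttnStub:
    # Stand-in for an xformers-style processor: returns ONE fused tensor.
    def __call__(self, hidden_states, encoder_hidden_states):
        return torch.cat([encoder_hidden_states, hidden_states], dim=1)

img = torch.randn(1, 16, 8)   # image-stream tokens
txt = torch.randn(1, 4, 8)    # text-stream tokens

# Works: the joint attention stub returns a tuple of two tensors.
attn_output, context_attn_output = JointAttnStub()(img, txt)

# Fails the same way as the traceback: unpacking a single tensor of batch size 1
# raises "ValueError: not enough values to unpack (expected 2, got 1)".
attn_output, context_attn_output = XFormersAttnStub()(img, txt)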
zejacky commented 3 months ago

Hello @henryfw. From what I've read so far, xformers can have compatibility issues with the flux1-dev model. I'm currently using the CUDA 12.4 and PyTorch 2.4.0 combination without xformers, and the same for ComfyUI. Training with a person was successful so far (ca. 1 h 24 min).
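On PyTorch 2.x the built-in scaled-dot-product attention kernels usually cover what xformers used to provide, so the setup can be checked without installing it. A quick sanity check using only standard PyTorch APIs (nothing toolkit-specific assumed):

import torch

print(torch.__version__)        # e.g. 2.4.0
print(torch.version.cuda)       # CUDA build the wheel was compiled against, e.g. 12.4
print(torch.cuda.is_available())

# Flash and memory-efficient SDPA backends shipped with PyTorch itself:
print(torch.backends.cuda.flash_sdp_enabled())
print(torch.backends.cuda.mem_efficient_sdp_enabled())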