Nerogar / OneTrainer

OneTrainer is a one-stop solution for all your stable diffusion training needs.
GNU Affero General Public License v3.0

[Bug]: SDXL DoRA training fails to start when base model weight data type is set to float8 #442

Open ikarsokolov opened 2 months ago

ikarsokolov commented 2 months ago

What happened?

When I activate Decomposed Weights (DoRA) training in the "Lora" tab and have the base SDXL model loaded as float8 in the "model" tab, the training process fails to start with RuntimeError: Promotion for Float8 Types is not supported, attempted to promote Float8_e4m3fn and Half.

If the DoRA toggle is deactivated, training starts as usual.

What did you expect would happen?

DoRA training working.

Relevant log output

epoch:   0%|                                                                                                                                                          | 0/100 [00:26<?, ?it/s]
Traceback (most recent call last):                                                                                                                                                            
  File "/home/user/apps/OneTrainer/scripts/train.py", line 38, in <module>                                                                                                               
    main()                                                                                                                                                                                    
  File "/home/user/apps/OneTrainer/scripts/train.py", line 29, in main                                                                                                                   
    trainer.train()                                                                                                                                                                           
  File "/home/user/apps/OneTrainer/modules/trainer/GenericTrainer.py", line 575, in train                                                                                                
    model_output_data = self.model_setup.predict(self.model, batch, self.config, train_progress)                                                                                              
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                              
  File "/home/user/apps/OneTrainer/modules/modelSetup/BaseStableDiffusionXLSetup.py", line 467, in predict                                                                               
    predicted_latent_noise = model.unet(                                                                                                                                                      
                             ^^^^^^^^^^^                                                                                                                                                      
  File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl                                                       
    return self._call_impl(*args, **kwargs)                                                                                                                                                   
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                   
  File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl                                                               
    return forward_call(*args, **kwargs)                                                                                                                                                      
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                      
  File "/home/user/apps/OneTrainer/venv/src/diffusers/src/diffusers/models/unets/unet_2d_condition.py", line 1135, in forward                                                            
    emb = self.time_embedding(t_emb, timestep_cond)                                                                                                                                           
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                           
  File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl                                                       
    return self._call_impl(*args, **kwargs)                                                                                                                                                   
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                   
  File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl                                                               
    return forward_call(*args, **kwargs)                                                                                                                                                      
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                      
  File "/home/user/apps/OneTrainer/venv/src/diffusers/src/diffusers/models/embeddings.py", line 376, in forward                                                                          
    sample = self.linear_1(sample)                                                                                                                                                            
             ^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                            
  File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl                                                       
    return self._call_impl(*args, **kwargs)                                                                                                                                                   
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                   
  File "/home/user/apps/OneTrainer/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/apps/OneTrainer/modules/module/LoRAModule.py", line 374, in forward
    WP = self.orig_module.weight + (self.make_weight(A, B) * (self.alpha / self.rank))
         ~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: Promotion for Float8 Types is not supported, attempted to promote Float8_e4m3fn and Half

Output of pip freeze

No response

mx commented 2 months ago

The correct way to handle this is a fairly subtle problem and might depend on the state of the Stable Diffusion ecosystem at large. It's not necessarily as simple as upcasting the float8 weights into the lora dtype, or vice versa. Sorry there's not a fast fix here.

mx commented 2 months ago

Autocast doesn't seem to work in my experiments, sadly (wrapped it in an autocast context for the train_dtype which was float16). We'd need to explicitly upcast the fp8 tensors. I'm not sure how to guarantee that's a safe operation that reflects unit scaling. Does Torch even have a standard unit scaling API? Is this implemented in an ad-hoc way by Nvidia Transformer Engine and others?
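To make the unit scaling concern concrete, here is a minimal pure-Python sketch of what an upcast recovers. It decodes the float8 e4m3fn bit layout (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities) and then applies an optional per-tensor scale. The helper names and the `scale` parameter are hypothetical and only illustrate the idea; they do not describe how any particular checkpoint stores its scale.

```python
import math

def decode_e4m3fn(byte: int) -> float:
    """Decode one float8 e4m3fn byte into a Python float (illustrative helper)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0x0F
    mantissa = byte & 0x07
    if exponent == 0x0F and mantissa == 0x07:
        return math.nan  # e4m3fn has no infinities; this bit pattern is NaN
    if exponent == 0:
        # Subnormal: no implicit leading 1, fixed exponent of 2**-6
        return sign * (mantissa / 8.0) * 2.0 ** -6
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)

def upcast(stored_bytes, scale=1.0):
    """Hypothetical upcast: decoded value times an assumed per-tensor scale.

    With scale=1.0 this is a plain (unscaled) upcast. If the weights had been
    saved with some other scale, a plain upcast would return the wrong values.
    """
    return [decode_e4m3fn(b) * scale for b in stored_bytes]

stored = [0x38, 0x40, 0xB8]
print(upcast(stored))             # plain upcast of the stored values
print(upcast(stored, scale=0.5))  # same bytes under an assumed scale of 0.5
```

The same bytes give different weights depending on the scale, which is why a plain upcast is only safe when the fp8 weights are known to be unscaled.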

hameerabbasi commented 2 months ago

> Does Torch even have a standard unit scaling API? Is this implemented in an ad-hoc way by Nvidia Transformer Engine and others?

It does.

mx commented 2 months ago

Looks like there's no standard unit scaling. It would be nice if bitsandbytes supported FP8, then everyone would just use it as standard! We're going to have to arbitrarily decide that FP8 weights need to be unscaled to fix this bug. I think anything present in the community that is actually FP8 is unscaled (to my knowledge), so this shouldn't be a big restriction. If we get future checkpoints in FP8 that were trained with Transformer Engine or something else, we'll need to revisit that. I'm hopeful that everyone just uses bitsandbytes for everything.

O-J1 commented 6 days ago

@ikarsokolov There is now a beta branch with proper support for fp8 LoRA training. The branch name is fp8; please give it a try if you know how to use branches.