a-r-r-o-w / cogvideox-factory

Memory optimized finetuning scripts for CogVideoX & Mochi using TorchAO and DeepSpeed
Apache License 2.0

cannot access local variable 'gradient_norm_before_clip' where it is not associated with a value #41

Open Yuancheng-Xu opened 1 month ago

Yuancheng-Xu commented 1 month ago

During both I2V and T2V training, I sometimes encounter the following error:

[rank1]:   File "/root/projects/cogvideox-factory/training/cogvideox_text_to_video_lora.py", line 762, in main
[rank1]:     "gradient_norm_before_clip": gradient_norm_before_clip,
[rank1]:                                  ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: UnboundLocalError: cannot access local variable 'gradient_norm_before_clip' where it is not associated with a value

This probably comes from the following code:

if accelerator.sync_gradients:
    gradient_norm_before_clip = get_gradient_norm(transformer.parameters())
    accelerator.clip_grad_norm_(transformer.parameters(), args.max_grad_norm)
    gradient_norm_after_clip = get_gradient_norm(transformer.parameters())

Somehow, accelerator.sync_gradients is sometimes False.
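The failure mode can be reproduced without Accelerate. The sketch below is a hypothetical, dependency-free simulation: run_loop and the local sync_gradients flag are illustrative names, and the flag is assumed to be True only on the last micro-step of each accumulation window (any extra sync Accelerate may force at the end of the dataloader is ignored). Because the log dictionary is built on every micro-step, the first micro-step reads a name that has never been assigned whenever gradient_accumulation_steps > 1.

```python
def run_loop(num_steps, gradient_accumulation_steps):
    """Simulate the logging pattern of the training loop (hypothetical sketch)."""
    history = []
    for step in range(num_steps):
        # Assumed flag behavior: gradients are synced only at the end of
        # each accumulation window.
        sync_gradients = (step + 1) % gradient_accumulation_steps == 0
        if sync_gradients:
            # Placeholder for get_gradient_norm(transformer.parameters()).
            gradient_norm_before_clip = 1.0
        # Built on every micro-step, synced or not: with accumulation > 1 the
        # first micro-step reads a local name that is not yet bound.
        logs = {"gradient_norm_before_clip": gradient_norm_before_clip}
        history.append(logs)
    return history


print(len(run_loop(4, gradient_accumulation_steps=1)))  # every step syncs: 4
try:
    run_loop(4, gradient_accumulation_steps=2)
except UnboundLocalError as err:
    print(type(err).__name__)  # UnboundLocalError
```

Once the first sync has happened the name stays bound, so in this simplified loop later non-sync micro-steps would silently log a stale value instead of crashing.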

Is there a quick fix? Is it only for logging?

a-r-r-o-w commented 1 month ago

Hmm, that is very weird... It means there was no gradient update for an entire epoch. Could that happen when training on a very small number of videos? How many videos were you using when this occurred, and what were your batch size and gradient accumulation steps?

Yuancheng-Xu commented 1 month ago

For example, with 2 GPUs the setup looks like this:

***** Running training *****
  Num trainable parameters = 132120576
  Num examples = 128
  Num batches each epoch = 64
  Num epochs = 5
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient accumulation steps = 16
  Total optimization steps = 20

Same error when using 8 GPUs:

***** Running training *****
  Num trainable parameters = 132120576
  Num examples = 128
  Num batches each epoch = 16
  Num epochs = 2
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient accumulation steps = 4
  Total optimization steps = 5

So I don't think the issue is a small number of training samples.
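The numbers in both logs fit the usual arithmetic of Accelerate-based training scripts. The sketch below is an assumption about how the script derives them (the function name and argument order are hypothetical), but it reproduces every value in both logs and shows that each epoch contains 4 gradient syncs, so no epoch should be without an optimizer step.

```python
import math


def training_schedule(num_examples, batch_size_per_device, num_processes,
                      gradient_accumulation_steps, max_train_steps):
    # Batches seen by each process per epoch (dataset sharded across processes).
    num_batches_per_epoch = math.ceil(num_examples / (batch_size_per_device * num_processes))
    # Optimizer steps (gradient syncs) per epoch.
    updates_per_epoch = math.ceil(num_batches_per_epoch / gradient_accumulation_steps)
    # Epochs needed to reach max_train_steps ("Total optimization steps").
    num_epochs = math.ceil(max_train_steps / updates_per_epoch)
    total_batch_size = batch_size_per_device * num_processes * gradient_accumulation_steps
    return num_batches_per_epoch, updates_per_epoch, num_epochs, total_batch_size


print(training_schedule(128, 1, 2, 16, 20))  # 2-GPU run: (64, 4, 5, 32)
print(training_schedule(128, 1, 8, 4, 5))    # 8-GPU run: (16, 4, 2, 32)
```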

a-r-r-o-w commented 1 month ago

Do you see the error when gradient_accumulation_steps is 1? I haven't really experimented with higher gradient_accumulation_steps yet, so I might have missed something. Most of my experiments use train_batch_size=4 and 2 or 4 GPUs with 100 videos.

The script could be improved a little by only logging the gradient norm when it has actually been computed. However, I fail to see how no gradient step would occur with the configurations you shared. I will take a deeper look over the weekend.
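A minimal sketch of the "only log the gradient norm when it was computed" idea, in plain Python so it runs on its own. get_gradient_norm and clip_grad_norm here are hypothetical stand-ins operating on a flat list of floats, not the script's helpers or accelerator.clip_grad_norm_ (PyTorch's clipping also adds a small epsilon, omitted here). The point is that the grad-norm keys are added to the log dictionary inside the sync branch instead of being read unconditionally.

```python
import math


def get_gradient_norm(grads):
    # Hypothetical stand-in: global L2 norm of a flat list of gradient values.
    return math.sqrt(sum(g * g for g in grads))


def clip_grad_norm(grads, max_norm):
    # Hypothetical stand-in for clipping: rescale in place if the norm is too large.
    norm = get_gradient_norm(grads)
    if norm > max_norm:
        scale = max_norm / norm
        grads[:] = [g * scale for g in grads]


def step_logs(step, grads, loss, gradient_accumulation_steps, max_grad_norm):
    # Assumed flag behavior: sync only at the end of each accumulation window.
    sync_gradients = (step + 1) % gradient_accumulation_steps == 0
    logs = {"loss": loss}
    if sync_gradients:
        # Norms are computed and logged only when gradients were actually synced.
        logs["gradient_norm_before_clip"] = get_gradient_norm(grads)
        clip_grad_norm(grads, max_grad_norm)
        logs["gradient_norm_after_clip"] = get_gradient_norm(grads)
    return logs


print(step_logs(0, [3.0, 4.0], 0.1, gradient_accumulation_steps=2, max_grad_norm=1.0))
print(step_logs(1, [3.0, 4.0], 0.1, gradient_accumulation_steps=2, max_grad_norm=1.0))
```

An alternative would be to initialize both norms to None before the loop, but omitting the keys keeps non-sync micro-steps from logging stale or placeholder values.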

Yuancheng-Xu commented 1 month ago

Yep, I just confirmed that: everything works fine with gradient_accumulation_steps=1. Whenever it is anything other than 1, the error occurs.

a-r-r-o-w commented 1 month ago

Okay, thanks, that helps. I will try debugging over the weekend.

a-r-r-o-w commented 1 month ago

@Yuancheng-Xu I found the issue, thanks for reporting! Somehow, every time I looked at the code before today, I thought the gradient norm was computed after each epoch rather than after each step. Computing it at each step is correct, but only when the gradient sync happens, which is now addressed in the PR. I'm running an experiment to verify that it works, so I will merge the PR later. Please give it a review too when you're free!

a-r-r-o-w commented 1 month ago

I think there might still be other mistakes, or a source of randomness coming from somewhere: with exactly the same training parameters, seeds and initializations, varying only train_batch_size and gradient_accumulation_steps, I can't get the loss curves to match. Validations for both look healthy so far, but I will wait for the experiments to complete to see whether it is working correctly now. Most likely it is a source of randomness that we forgot to seed (I can already see that we don't seed the normal distribution used to create the noisy image latents).
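One way to make noise independent of how samples are grouped into batches is to draw it from a dedicated generator keyed on the sample rather than from one shared stream. The sketch below uses Python's random module as a stand-in (in PyTorch the analogous tool is a torch.Generator seeded with manual_seed and passed as generator= to torch.randn); the per-sample seed formula and the function names are hypothetical, not what the repository does. It also shows why a shared stream drifts: when another per-batch draw (such as a timestep sample) is interleaved with the noise draws, regrouping the same samples shifts the sequence.

```python
import random


def per_sample_noise(sample_indices, numel, base_seed=0):
    # Hypothetical per-sample seeding: each sample's noise depends only on
    # (base_seed, sample_index), not on batch size or other RNG consumers.
    out = []
    for i in sample_indices:
        rng = random.Random(base_seed * 1_000_003 + i)
        out.append([rng.gauss(0.0, 1.0) for _ in range(numel)])
    return out


def shared_stream_noise(batches, numel, seed=0):
    # One shared stream with a per-batch draw interleaved between noise draws.
    rng = random.Random(seed)
    out = []
    for batch in batches:
        rng.random()  # stands in for some other per-batch random draw
        for _ in batch:
            out.append([rng.gauss(0.0, 1.0) for _ in range(numel)])
    return out


# Same four samples: one batch of 4 versus two batches of 2.
keyed_big = per_sample_noise([0, 1, 2, 3], numel=3)
keyed_small = per_sample_noise([0, 1], numel=3) + per_sample_noise([2, 3], numel=3)
shared_big = shared_stream_noise([[0, 1, 2, 3]], numel=3)
shared_small = shared_stream_noise([[0, 1], [2, 3]], numel=3)
print(keyed_big == keyed_small, shared_big == shared_small)
```

Data shuffling and sharding across GPUs also have to line up for the curves to match, which this sketch does not model.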

Yuancheng-Xu commented 1 month ago

Thanks @a-r-r-o-w ! I left a comment here.

As for the actual training outcomes (learning curves, generated validation samples), I would expect the results to stay the same.
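That expectation rests on a simple identity: if the loss is a per-sample mean and each micro-batch loss is scaled by 1/k before backward, accumulating k equal-sized micro-batches yields the same gradient as a single batch k times larger. The sketch below checks this in exact rational arithmetic for a toy one-parameter model (the model and function names are illustrative, not taken from the script). In practice, floating-point summation order and unseeded randomness can still make the curves drift apart.

```python
from fractions import Fraction


def grad_mse(w, xs, ys):
    # d/dw of mean((w * x - y) ** 2) over one (micro-)batch.
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, micro_batch_size):
    # Sum micro-batch gradients, each scaled by 1 / number_of_micro_batches.
    num_micro = len(xs) // micro_batch_size
    total = Fraction(0)
    for i in range(num_micro):
        start, end = i * micro_batch_size, (i + 1) * micro_batch_size
        total += grad_mse(w, xs[start:end], ys[start:end]) / num_micro
    return total


xs = [Fraction(v) for v in range(1, 9)]
ys = [Fraction(3 * v + 1, 2) for v in range(1, 9)]
w = Fraction(1, 3)

full = grad_mse(w, xs, ys)  # one batch of 8
print(all(accumulated_grad(w, xs, ys, m) == full for m in (1, 2, 4, 8)))
```

The same reasoning covers data parallelism, since averaging gradients across processes is another equal-weight mean; with unequal micro-batch sizes (for example a short final batch) the identity no longer holds exactly.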