fudan-generative-vision / hallo

Hallo: Hierarchical Audio-Driven Visual Synthesis for Portrait Image Animation
https://fudan-generative-vision.github.io/hallo/
MIT License

Some bugs when gradient_checkpointing is set to True #162

Open progrobe opened 4 months ago

progrobe commented 4 months ago

Thank you for your open research and exploration!

It seems that there are some bugs in stage 2 when gradient_checkpointing is set to True.

[Screenshot of the error message]

[Reproduction] This error occurs when running train_stage2.py on multiple GPUs with gradient_checkpointing enabled to save memory.

[Error Description] As reported in the screenshot above, this error is caused by some parameters not receiving gradients during training. After checking their gradients, I found that they are None. Distributed (DDP) training raises an error when a parameter that requires grad ends up with a None gradient. By default, this check is only performed in DDP mode, so training on a single GPU works fine.
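A minimal sketch of the gradient check described above, using a hypothetical toy model rather than the hallo UNet: after backward(), list every parameter that requires grad but still has a None gradient. The helper params_without_grad and the Toy class are made up for illustration.

```python
import torch
import torch.nn as nn


def params_without_grad(model: nn.Module) -> list[str]:
    """Names of trainable parameters whose .grad is still None after backward().

    These are the kind of parameters DDP reports as not having received a
    gradient: they require grad but took no part in computing the loss.
    """
    return [
        name
        for name, param in model.named_parameters()
        if param.requires_grad and param.grad is None
    ]


class Toy(nn.Module):
    # Hypothetical two-layer model; `skip_second` mimics a code branch that
    # forgets to call one of its submodules.
    def __init__(self, skip_second: bool):
        super().__init__()
        self.first = nn.Linear(4, 4)
        self.second = nn.Linear(4, 4)
        self.skip_second = skip_second

    def forward(self, x):
        x = self.first(x)
        if not self.skip_second:
            x = self.second(x)
        return x


for skip in (False, True):
    model = Toy(skip_second=skip)
    model(torch.randn(2, 4)).sum().backward()
    print(f"skip_second={skip}:", params_without_grad(model))
```

With skip_second=False the list is empty; with skip_second=True it contains second.weight and second.bias, the toy equivalent of the parameters named in the error.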

[Analysis] Going through the code again, I found that https://github.com/fudan-generative-vision/hallo/blob/83dd4ebf52baa27de737045773d4fc4163d7c820/hallo/models/unet_3d_blocks.py#L1386 does not run the motion module when gradient_checkpointing is on, so the denoising_unet.up_blocks.0.motion_modules parameters from AnimateDiff are not used in training. The same happens in denoising_unet.down_blocks.3.motion_modules, since https://github.com/fudan-generative-vision/hallo/blob/83dd4ebf52baa27de737045773d4fc4163d7c820/hallo/models/unet_3d_blocks.py#L916 does not run the motion module while its corresponding else branch does. (This also explains why the problem does not occur when gradient_checkpointing=False.) These are exactly the parameters reported in the error screenshot above, so I think this is probably the cause of the error.
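To illustrate the asymmetry described in the analysis, here is a hypothetical toy block (ToyDownBlock, not the actual hallo DownBlock3D or UpBlock3D) whose checkpointed branch skips the motion module while the plain branch runs it. Single nn.Linear layers stand in for the resnet and the motion module, and use_reentrant=False is just a choice for this sketch.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyDownBlock(nn.Module):
    """Hypothetical stand-in for a block with a resnet and a motion module."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.resnet = nn.Linear(dim, dim)
        self.motion_module = nn.Linear(dim, dim)
        self.gradient_checkpointing = False

    def forward(self, hidden_states):
        if self.training and self.gradient_checkpointing:
            hidden_states = torch.utils.checkpoint.checkpoint(
                self.resnet, hidden_states, use_reentrant=False
            )
            # The bug being illustrated: motion_module is never called here.
        else:
            hidden_states = self.resnet(hidden_states)
            hidden_states = self.motion_module(hidden_states)
        return hidden_states


def unused_params(block: ToyDownBlock, gradient_checkpointing: bool) -> list[str]:
    """Run one forward/backward pass and return parameters left with grad None."""
    block.zero_grad(set_to_none=True)
    block.gradient_checkpointing = gradient_checkpointing
    block.train()
    x = torch.randn(2, 8, requires_grad=True)
    block(x).sum().backward()
    return [name for name, p in block.named_parameters() if p.grad is None]


block = ToyDownBlock()
print(unused_params(block, gradient_checkpointing=False))  # []
print(unused_params(block, gradient_checkpointing=True))   # ['motion_module.weight', 'motion_module.bias']
```

Only the checkpointed path leaves motion_module without gradients, which matches the observation that the error does not appear with gradient_checkpointing=False.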

[Solution] If the analysis above is correct, simply adding the code

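    # Mirror the non-checkpointed else branch: also run the motion module
    # under gradient checkpointing so its parameters receive gradients.
    if motion_module is not None:
        hidden_states = torch.utils.checkpoint.checkpoint(
            create_custom_forward(motion_module),
            hidden_states.requires_grad_(),
            temb,
            encoder_hidden_states,
        )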

to support the motion module should solve this problem. Alternatively, we can set requires_grad=False to omit the motion modules in DownBlock3D and UpBlock3D if they are not needed there. In my experiments, both solutions work fine.
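For the second option (freezing the motion modules instead of running them), a minimal sketch: the hypothetical helper freeze_by_prefix sets requires_grad=False on every parameter whose name starts with one of the given prefixes. The toy module below only reproduces the parameter naming from this issue (down_blocks.N.motion_modules...); applying the same prefixes to denoising_unet in hallo is an assumption.

```python
import torch.nn as nn


def freeze_by_prefix(model: nn.Module, prefixes: tuple[str, ...]) -> list[str]:
    """Set requires_grad=False on parameters whose names start with a prefix.

    Returns the frozen parameter names so the caller can log them.
    """
    frozen = []
    for name, param in model.named_parameters():
        if name.startswith(prefixes):
            param.requires_grad_(False)
            frozen.append(name)
    return frozen


def make_block() -> nn.Module:
    # Hypothetical block with the same attribute names as in the issue.
    block = nn.Module()
    block.resnets = nn.ModuleList([nn.Linear(4, 4)])
    block.motion_modules = nn.ModuleList([nn.Linear(4, 4)])
    return block


# Hypothetical toy UNet; not the hallo model.
unet = nn.Module()
unet.down_blocks = nn.ModuleList([make_block() for _ in range(4)])
unet.up_blocks = nn.ModuleList([make_block() for _ in range(4)])

frozen = freeze_by_prefix(
    unet, ("down_blocks.3.motion_modules", "up_blocks.0.motion_modules")
)
print(frozen)
```

Presumably this should run before the model is wrapped in DistributedDataParallel and before the optimizer is built, since DDP decides which parameters it synchronizes when it is constructed.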

I'm not sure whether I solved this problem correctly, since nobody else has reported these bugs. Does that mean I am the only one who has encountered these errors? Any help is appreciated!

skywalker00001 commented 2 months ago

I didn't get this error myself, but I do think this is a bug. Your modification seems correct. Another question: why is "hidden_states.requires_grad_()," in your added code? Is "requires_grad_()" on hidden_states necessary? Why?

progrobe commented 2 months ago

I'm not sure whether "requires_grad_()" on hidden_states is necessary. I added it to prevent this issue when using multiple GPUs: https://discuss.pytorch.org/t/use-of-torch-utils-checkpoint-checkpoint-causes-simple-model-to-diverge/116271/2. But I didn't check whether it's necessary.
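One likely reason requires_grad_() can matter: the PyTorch documentation for torch.utils.checkpoint.checkpoint notes that with use_reentrant=True, at least one input must have requires_grad=True, otherwise the checkpointed part gets no gradients. The sketch below uses a hypothetical toy in which a frozen layer feeds a trainable one; treating this as a stand-in for frozen layers feeding the trainable motion modules in stage 2 is an assumption, not something confirmed by the hallo code.

```python
import warnings

import torch
import torch.nn as nn
import torch.utils.checkpoint

frozen_layer = nn.Linear(8, 8).requires_grad_(False)  # hypothetical frozen upstream layer
motion_module = nn.Linear(8, 8)                        # hypothetical trainable checkpointed layer


def motion_module_gets_grad(use_reentrant: bool, mark_input: bool) -> bool:
    """Return True if motion_module.weight receives a gradient."""
    motion_module.zero_grad(set_to_none=True)
    # Output of a frozen layer on a plain input: requires_grad is False here.
    hidden_states = frozen_layer(torch.randn(2, 8))
    if mark_input:
        hidden_states = hidden_states.requires_grad_()
    with warnings.catch_warnings():
        # Reentrant checkpointing warns when no input requires grad.
        warnings.simplefilter("ignore")
        out = torch.utils.checkpoint.checkpoint(
            motion_module, hidden_states, use_reentrant=use_reentrant
        )
    if not out.requires_grad:
        return False  # the checkpointed segment is cut off from autograd
    out.sum().backward()
    return motion_module.weight.grad is not None


for use_reentrant in (True, False):
    for mark_input in (False, True):
        print(
            f"use_reentrant={use_reentrant}, requires_grad_()={mark_input}:",
            motion_module_gets_grad(use_reentrant, mark_input),
        )
```

Only the reentrant case without requires_grad_() prints False; the non-reentrant variant propagates gradients to the checkpointed parameters either way.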