I ran into some bugs when running `train_stage2.py` with multiple GPUs and `gradient_checkpointing` enabled. The following error occurs when the unconditional audio forward runs.
[Reproduction]
This error occurs when running `train_stage2.py` with multiple GPUs, with `gradient_checkpointing` enabled and `uncond_audio_fwd` set to True.
[Error Description]
1. As reported in the screenshot above, this error is directly caused by some parameters not receiving gradients during training. After checking their gradients, I found they have `None` grad. Distributed training reports an error when a grad-requiring parameter has a `None` grad. By default, this gradient check only runs in DDP mode, so training on a single GPU is fine.
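For reference, a minimal sketch of the check I did (a toy two-layer model stands in for the real network; only the first layer participates in the loss, so the second layer's grads stay `None`, which is exactly the condition DDP complains about):

```python
import torch

# toy stand-in for the real model; the same check applies to the hallo net
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))

x = torch.randn(2, 4)
loss = model[0](x).sum()   # only the first layer participates in the loss
loss.backward()

# list trainable parameters whose grad is still None after backward
missing = [n for n, p in model.named_parameters()
           if p.requires_grad and p.grad is None]
print(missing)             # ['1.weight', '1.bias']
```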
2. The way gradient checkpointing is implemented, whether the output tensor of a checkpointed module requires gradients is determined solely by the `requires_grad` value of its input tensor. In other words, even if the model parameters are trainable and have `requires_grad=True`, the output tensor will not require grads if the input tensor does not. This in turn means the model parameters wrapped in the checkpointed module receive no grads at all.
Ref: https://discuss.pytorch.org/t/use-of-torch-utils-checkpoint-checkpoint-causes-simple-model-to-diverge/116271
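A minimal, self-contained reproduction of this behavior, independent of the repo (I'm assuming a PyTorch version where `use_reentrant` can be passed explicitly; on older versions, omit it, as reentrant was the default):

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(4, 4)   # weight and bias have requires_grad=True
x = torch.zeros(2, 4)           # input does NOT require grad,
                                # like the zeroed-out audio_emb

# reentrant checkpointing (the long-standing default implementation)
out = checkpoint(layer, x, use_reentrant=True)
print(out.requires_grad)        # False: the graph is cut here, so the
                                # layer's parameters end up with grad=None
```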
[Analysis]
This error only occurs in the unconditional audio forward, in which `audio_emb` is set to zeros: https://github.com/fudan-generative-vision/hallo/blob/83dd4ebf52baa27de737045773d4fc4163d7c820/scripts/train_stage2.py#L164
This can cause two problems:
1. The `audioproj` module is trainable and requires grads, but it is skipped entirely, so it cannot receive any grads in backprop. This triggers the error in the way described in Error Description 1.
2. More trainable parts of the net will not receive grads in the way described in Error Description 2, since `audio_emb` does not require grads in the unconditional audio forward. By contrast, in the conditional forward, `audioproj` requires grads, which makes `audio_emb` require grads as well, avoiding the problem.
Caused by 1 and 2, many parameters of the model do not get their grads properly. Part of the parameters that lost their grads are listed here:
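To make the two paths concrete, here is a paraphrased sketch (the identifiers `audioproj`, `audio_feature`, and `uncond_audio_fwd` follow this issue's wording, not the exact repo code; see the linked line for the real implementation):

```python
import torch

audioproj = torch.nn.Linear(32, 32)   # trainable projection module
audio_feature = torch.randn(2, 32)

uncond_audio_fwd = True
if uncond_audio_fwd:
    # fresh zeros: requires_grad=False, and audioproj is never executed,
    # so it receives no grads at all (problem 1), and every checkpointed
    # module fed by audio_emb is cut from the graph (problem 2)
    audio_emb = torch.zeros(2, 32)
else:
    # audioproj's parameters require grad, so audio_emb requires grad too
    audio_emb = audioproj(audio_feature)

print(audio_emb.requires_grad)        # False in the unconditional branch
```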
[Solution]
One possible solution is to turn off the unconditional audio forward by setting `uncond_audio_ratio: 0` in the training config.
Another possible solution is to set `requires_grad=True` for `audio_emb` in the unconditional audio forward. This only handles the problem discussed in Analysis 2, leaving the problem in Analysis 1 unsolved.
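A sketch of what solution 2 looks like at the point where `audio_emb` is zeroed out (the shape is illustrative):

```python
import torch

audio_emb = torch.zeros(2, 32)   # unconditional path: zeroed embedding
audio_emb.requires_grad_(True)   # reconnects downstream checkpointed modules
                                 # to the autograd graph (fixes Analysis 2);
                                 # audioproj is still skipped, so Analysis 1
                                 # remains
```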
I've tested these 2 solutions, and they work as expected.
I'm not sure whether I'm doing the right thing, since nobody else has reported these bugs. Does that mean I'm the only one who has encountered them? I'd really appreciate your help!
Thank you for your open research and exploration!