wangyanmeng / FedTAN

PyTorch implementation of FedTAN (a federated learning algorithm tailored for batch normalization), proposed in the paper "Why Batch Normalization Damage Federated Learning on Non-IID Data".
MIT License

Computing dL/dx, the partial derivative of the loss with respect to the BN layer's input, in the code #2

Open fungizhang opened 5 months ago

fungizhang commented 5 months ago

`grad_model_2_output = (grad_bn_11_input_0_list_2[0].to(device)`

In training.py, the line above computes the partial derivative dL/dx. However, it does not agree with the equation I found:

[image: the equation for dL/dx]

The first term: [image]

The author uses `grad_bn_11_input_0_list_2[0]` (i.e., dL/dx) instead of $\frac{dL}{d\hat{x}} \cdot \frac{1}{\sqrt{\sigma^2 + \epsilon}}$. Could you help explain this? Thank you very much!
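For reference, the backward pass given in the original batch-normalization paper (Ioffe and Szegedy, 2015), which is presumably what the image shows, is reproduced here for a single channel with $m$ = batchsize × height × width values and biased variance $\sigma^2$:

$$
\mu = \frac{1}{m}\sum_{j=1}^{m} x_j, \qquad \sigma^2 = \frac{1}{m}\sum_{j=1}^{m} (x_j - \mu)^2, \qquad \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}
$$

$$
\frac{\partial L}{\partial \sigma^2} = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\,(x_i - \mu)\cdot\left(-\frac{1}{2}\right)(\sigma^2 + \epsilon)^{-3/2}
$$

$$
\frac{\partial L}{\partial \mu} = \left(\sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i}\cdot\frac{-1}{\sqrt{\sigma^2 + \epsilon}}\right) + \frac{\partial L}{\partial \sigma^2}\cdot\frac{\sum_{i=1}^{m} -2(x_i - \mu)}{m}
$$

$$
\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i}\cdot\frac{1}{\sqrt{\sigma^2 + \epsilon}} + \frac{\partial L}{\partial \sigma^2}\cdot\frac{2(x_i - \mu)}{m} + \frac{\partial L}{\partial \mu}\cdot\frac{1}{m}
$$

The first term of $\partial L / \partial x_i$ carries the factor $1/\sqrt{\sigma^2 + \epsilon}$. The other two are the contributions that reach $x_i$ through $\sigma^2$ and $\mu$. The second summand of $\partial L / \partial \mu$ is zero, because $\sum_i (x_i - \mu) = 0$.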

wangyanmeng commented 5 months ago

To take advantage of PyTorch's automatic differentiation, as shown in FedTAN/nets/Sub_ResNet20.py, we split the batch normalization computation into five modules: BatchNorm2dInput (bn_Input), BatchNorm2dModule11 (bn_11), BatchNorm2dModule12 (bn_12), BatchNorm2dModule13 (bn_13), and BatchNorm2dModule2 (bn_2). This structure ensures that the gradients of bn_11 and bn_12 are backpropagated into bn_Input.

In bn_11, the computation is defined as `output1 = input_ - mean.unsqueeze(1)`, with `mean` passed in as a separate input. Therefore, the gradient with respect to `input_` is $\frac{dL}{d(input)} = \frac{dL}{d(output1)} \cdot \frac{d(output1)}{d(input)} = \frac{dL}{d(output1)}$. Consequently, we use `grad_bn_11_input_0_list_2[0]` directly.
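A minimal sketch of that chain-rule step in plain Python, for one channel flattened to a list. The functions `bn_11_forward`, `bn_11_backward` and `numerical_grads` are hypothetical illustrations, not the repository's code. The sketch assumes only what is stated above: bn_11 computes `output1 = input_ - mean` and receives `mean` as a separate input. Under that assumption, the local gradient with respect to `input_` is the identity. The gradient with respect to `mean` is minus the sum of the incoming gradients. A finite-difference check confirms both.

```python
from typing import List, Tuple


def bn_11_forward(input_: List[float], mean: float) -> List[float]:
    """Hypothetical stand-in for bn_11 on one channel: output1 = input_ - mean."""
    return [x - mean for x in input_]


def bn_11_backward(grad_output1: List[float]) -> Tuple[List[float], float]:
    """Local backward pass of output1 = input_ - mean.

    d(output1_i)/d(input_i) = 1, so the gradient w.r.t. input_ is grad_output1 unchanged.
    d(output1_i)/d(mean) = -1 for every i, so the gradient w.r.t. mean is -sum(grad_output1).
    """
    grad_input = list(grad_output1)
    grad_mean = -sum(grad_output1)
    return grad_input, grad_mean


def numerical_grads(
    input_: List[float], mean: float, weights: List[float], h: float = 1e-6
) -> Tuple[List[float], float]:
    """Central finite differences of L = sum(w_i * output1_i) w.r.t. input_ and mean."""

    def loss(xs: List[float], mu: float) -> float:
        return sum(w * o for w, o in zip(weights, bn_11_forward(xs, mu)))

    grad_input = []
    for i in range(len(input_)):
        plus = list(input_)
        plus[i] += h
        minus = list(input_)
        minus[i] -= h
        grad_input.append((loss(plus, mean) - loss(minus, mean)) / (2 * h))
    grad_mean = (loss(input_, mean + h) - loss(input_, mean - h)) / (2 * h)
    return grad_input, grad_mean


if __name__ == "__main__":
    x = [0.5, -1.0, 2.0, 3.5]
    mu = sum(x) / len(x)
    # For the linear test loss L = sum(w_i * output1_i), dL/d(output1) is just w.
    w = [0.3, -0.7, 1.1, 0.2]
    print("analytic:", bn_11_backward(w))
    print("numeric: ", numerical_grads(x, mu, w))
```

This covers only the path through bn_11's first input. Any dependence of `mean` on `input_` is handled by whichever module computes `mean`.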

fungizhang commented 5 months ago

Thank you for your response! Since `grad_bn_11_input_0_list_2[0]` is dL/d(output1), why does the code compute

`grad_model_2_output = (grad_bn_11_input_0_list_2[0].to(device)`

instead of the following?

`grad_model_2_output = (grad_bn_11_input_0_list_2[0].to(device) + grad_bn_11_input_1_list_2[0].to(device).unsqueeze(1) / (batchsize * height * width) + grad_bn_12_input_0_list_2[0].to(device).unsqueeze(1) * (input_list_2[0] - mean_list_2[0].unsqueeze(1)) * 2 / (batchsize * height * width))`

Looking forward to your reply, thanks!
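A sketch of the three-term expression proposed above, checked numerically. It works on one channel flattened to m values and uses plain Python. The names `grad_output1`, `grad_mean` and `grad_var` are hypothetical stand-ins for `grad_bn_11_input_0_list_2[0]`, `grad_bn_11_input_1_list_2[0]` and `grad_bn_12_input_0_list_2[0]`. Reading the last one as dL/dσ² is an inference from the structure of the expression, not something the repository confirms. The test loss is L = Σ w_i x̂_i, so dL/dx̂_i = w_i.

```python
import math
from typing import List


def bn_normalize(x: List[float], eps: float = 1e-5) -> List[float]:
    """Forward pass for one channel: x_hat_i = (x_i - mu) / sqrt(var + eps), biased variance."""
    m = len(x)
    mu = sum(x) / m
    var = sum((xi - mu) ** 2 for xi in x) / m
    return [(xi - mu) / math.sqrt(var + eps) for xi in x]


def bn_input_grad(x: List[float], grad_x_hat: List[float], eps: float = 1e-5) -> List[float]:
    """dL/dx_i as the sum of three paths: direct (via output1), via the variance, via the mean."""
    m = len(x)
    mu = sum(x) / m
    var = sum((xi - mu) ** 2 for xi in x) / m
    inv_std = 1.0 / math.sqrt(var + eps)
    centered = [xi - mu for xi in x]  # output1 = x - mu
    # Direct path: dL/d(output1) with the variance held fixed.
    grad_output1 = [g * inv_std for g in grad_x_hat]
    # dL/d(var): x_hat depends on var through (var + eps) ** -0.5.
    grad_var = sum(g * c for g, c in zip(grad_x_hat, centered)) * -0.5 * inv_std ** 3
    # dL/d(mu) through output1; the extra term via var is zero because sum(centered) == 0.
    grad_mean = -sum(grad_output1)
    return [
        go + grad_var * 2.0 * c / m + grad_mean / m
        for go, c in zip(grad_output1, centered)
    ]


def numerical_input_grad(
    x: List[float], weights: List[float], eps: float = 1e-5, h: float = 1e-6
) -> List[float]:
    """Central finite differences of L = sum(w_i * x_hat_i) w.r.t. each x_i."""

    def loss(xs: List[float]) -> float:
        return sum(w * xh for w, xh in zip(weights, bn_normalize(xs, eps)))

    grads = []
    for i in range(len(x)):
        plus = list(x)
        plus[i] += h
        minus = list(x)
        minus[i] -= h
        grads.append((loss(plus) - loss(minus)) / (2 * h))
    return grads


if __name__ == "__main__":
    x = [0.5, -1.0, 2.0, 3.5]
    w = [0.3, -0.7, 1.1, 0.2]  # dL/dx_hat for the test loss L = sum(w_i * x_hat_i)
    print("three-term:", bn_input_grad(x, w))
    print("numeric:   ", numerical_input_grad(x, w))
```

Keeping only the direct `grad_output1` term does not match the finite-difference gradient in general. The variance and mean contributions therefore have to reach the input by some route. Because normalization is invariant to shifting every x_i by the same constant, the full gradient also sums to zero over the channel.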

wangyanmeng commented 3 months ago

As previously mentioned, our code implements batch normalization through five modules: BatchNorm2dInput (bn_Input), BatchNorm2dModule11 (bn_11), BatchNorm2dModule12 (bn_12), BatchNorm2dModule13 (bn_13), and BatchNorm2dModule2 (bn_2). For the detailed implementation, please refer to the code.

During backpropagation, the value of `grad_model_2_output` is affected by BatchNorm2dInput (bn_Input), BatchNorm2dModule11 (bn_11), and BatchNorm2dModule12 (bn_12), because the gradients of bn_11 and bn_12 are backpropagated into bn_Input.