facebookresearch / CrypTen

A framework for Privacy Preserving Machine Learning
MIT License

How are BatchNorm layers handled? #472

Closed: kwmaeng91 closed this issue 1 year ago

kwmaeng91 commented 1 year ago

Dear Experts,

I am running ResNet18 through CrypTen and trying to follow how the internals work. However, I cannot find how BatchNorm is handled, as opposed to other layers. My initial understanding is that compute nodes are dispatched in nn/module.py, at https://github.com/facebookresearch/CrypTen/blob/6ef151101668591bcfb2bbf7e7ebd39ab6db0413/crypten/nn/module.py#LL714C21-L714C21 So I added a print statement there to print node_to_compute.

What I found was the following (full log omitted):

...
/conv1/Conv_output_0
/relu/Relu_output_0
/maxpool/MaxPool_output_0
/layer1/layer1.0/conv1/Conv_output_0
/layer1/layer1.0/relu/Relu_output_0
/layer1/layer1.0/conv2/Conv_output_0
/layer1/layer1.0/Add_output_0
...

The log does not show BatchNorm, which should appear between the first Conv and ReLU! I am not sure why this is happening. In the end I see no accuracy drop, so I don't think BatchNorm is being skipped, but I don't understand why my test code does not capture it.

After some more digging, it seems that the ONNX export does not emit BatchNorm as a separate node and instead merges it into the Conv (https://github.com/pytorch/pytorch/pull/40547), so the ONNX graph does not show any BN to begin with. However, CrypTen's conversion of ONNX Conv nodes does not appear to add a BN after the Conv either (https://github.com/facebookresearch/CrypTen/blob/6ef151101668591bcfb2bbf7e7ebd39ab6db0413/crypten/nn/module.py#L1963). I am confused about how CrypTen still works correctly without an accuracy drop.

Can anyone help me understand why this is the case? Does CrypTen handle BatchNorm differently from other operators?

Thank you!

lvdmaaten commented 1 year ago

I don't think there is anything special about how batch normalization is implemented in CrypTen. If ONNX somehow fuses the batch normalization layer into the convolution layer (which seems like a really odd thing to do for an intermediate representation like ONNX), then I doubt that will be handled correctly on the CrypTen side.

Looking at https://github.com/pytorch/pytorch/pull/40547, it appears the behavior may be different between training and evaluation mode. Does something change in your testing code if you switch between pytorch_model.train(mode=True) and pytorch_model.train(mode=False)? Are you seeing differences in the output between CrypTen and PyTorch in training and/or inference mode?
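One reason the two modes could behave differently: in evaluation mode BatchNorm normalizes with its stored running statistics, so it is a fixed per-channel affine map that could be folded into the preceding Conv; in training mode it normalizes with the current batch's statistics, which no fixed set of weights can reproduce. The following is a minimal pure-Python sketch of that distinction for a single channel. The helper names `batchnorm_train` and `batchnorm_eval` are hypothetical and are not PyTorch or CrypTen APIs.

```python
import math

def batchnorm_train(xs, gamma, beta, eps=1e-5):
    """Training-mode BN for one channel: normalize with the current batch's statistics."""
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)  # biased variance, used for normalization
    return [gamma * (x - mean) / math.sqrt(var + eps) + beta for x in xs]

def batchnorm_eval(xs, gamma, beta, running_mean, running_var, eps=1e-5):
    """Eval-mode BN for one channel: a fixed affine map built from running statistics."""
    scale = gamma / math.sqrt(running_var + eps)
    shift = beta - running_mean * scale
    return [scale * x + shift for x in xs]

gamma, beta = 2.0, 0.5
running_mean, running_var = 1.0, 4.0

# Eval mode: the output for 3.0 does not depend on the rest of the batch.
a = batchnorm_eval([3.0, 0.0], gamma, beta, running_mean, running_var)
b = batchnorm_eval([3.0, 10.0], gamma, beta, running_mean, running_var)
print(a[0] == b[0])  # True

# Training mode: the output for 3.0 changes with the other batch elements,
# so it cannot be baked into fixed Conv weights.
c = batchnorm_train([3.0, 0.0], gamma, beta)
d = batchnorm_train([3.0, 10.0], gamma, beta)
print(c[0] == d[0])  # False
```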

kwmaeng91 commented 1 year ago

Thanks for the answer. I dug a bit deeper, and it seems CrypTen is working correctly: ONNX simply merges Conv and BN into a single Conv by rescaling the Conv weights and adjusting its bias. So CrypTen only needs to run that Conv, which computes the same function as the original Conv + BN in evaluation mode.
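The Conv + BN merge can be checked numerically. The following is a minimal pure-Python sketch assuming a 1-D convolution with a single input channel and eval-mode BatchNorm; the helper names (`conv1d`, `batchnorm_eval`, `fold_bn_into_conv`) are hypothetical and are not the PyTorch exporter's or CrypTen's code. Per output channel, the weights are scaled by gamma / sqrt(var + eps) and the bias becomes (b - mean) * gamma / sqrt(var + eps) + beta. Because BN acts per output channel, the same formula carries over to 2-D convolutions such as those in ResNet18.

```python
import math

def conv1d(x, weights, bias):
    """Valid (no padding) 1-D cross-correlation, as in a Conv1d layer.
    x: list of floats (one input channel); weights: [out_ch][k]; bias: [out_ch]."""
    k = len(weights[0])
    return [
        [sum(w[j] * x[i + j] for j in range(k)) + b for i in range(len(x) - k + 1)]
        for w, b in zip(weights, bias)
    ]

def batchnorm_eval(y, gamma, beta, mean, var, eps=1e-5):
    """Eval-mode BatchNorm applied per output channel using running statistics."""
    return [
        [g * (v - m) / math.sqrt(s + eps) + bt for v in ch]
        for ch, g, bt, m, s in zip(y, gamma, beta, mean, var)
    ]

def fold_bn_into_conv(weights, bias, gamma, beta, mean, var, eps=1e-5):
    """Return (weights, bias) of one Conv equivalent to Conv followed by eval-mode BN."""
    new_w, new_b = [], []
    for w, b, g, bt, m, s in zip(weights, bias, gamma, beta, mean, var):
        scale = g / math.sqrt(s + eps)          # per-output-channel rescaling
        new_w.append([scale * wj for wj in w])  # rescale the Conv weights
        new_b.append((b - m) * scale + bt)      # absorb mean and beta into the bias
    return new_w, new_b

x = [0.5, -1.0, 2.0, 3.0, -0.5]
weights = [[1.0, 0.0, -1.0], [0.2, 0.5, 0.3]]
bias = [0.1, -0.2]
gamma, beta = [1.5, 0.8], [0.0, 0.3]
mean, var = [0.4, -0.1], [2.0, 0.5]

reference = batchnorm_eval(conv1d(x, weights, bias), gamma, beta, mean, var)
fused_w, fused_b = fold_bn_into_conv(weights, bias, gamma, beta, mean, var)
fused = conv1d(x, fused_w, fused_b)

max_err = max(abs(r - f) for rc, fc in zip(reference, fused) for r, f in zip(rc, fc))
print(max_err < 1e-12)  # True
```

Since the fused layer is an ordinary Conv, a converter that only understands Conv nodes needs no BatchNorm-specific logic for such a graph.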