silvandeleemput / memcnn

PyTorch Framework for Developing Memory Efficient Deep Invertible Networks
MIT License

perf: Computation profiling and optimization #52

Closed: ClashLuke closed this issue 3 years ago

ClashLuke commented 4 years ago


Profiling memcnn with torch.autograd.profiler.profile(), I got the following results after training a custom CNN (AdditiveCoupling, InvertibleModuleWrapper) for 256 batches @ 256 items per batch on CPU.

---------------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Name                                           Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls  
---------------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
flip                                           1.70%            38.087s          1.70%            38.087s          901.693us        42240            
native_batch_norm                              9.80%            219.675s         9.80%            219.675s         3.250ms          67584            
select                                         33.77%           757.009s         33.77%           757.009s         62.078us         12194560         
relu                                           1.04%            23.237s          1.04%            23.237s          687.644us        33792            
mul                                            5.01%            112.274s         5.01%            112.274s         1.388ms          80896            
mkldnn_convolution                             12.47%           279.614s         12.47%           279.614s         4.965ms          56320            
bernoulli_                                     3.20%            71.749s          3.20%            71.749s          3.185ms          22528            
sum                                            2.81%            62.890s          2.81%            62.890s          2.275ms          27648            
mm                                             1.04%            23.246s          1.04%            23.246s          840.786us        27648            
isnan                                          3.99%            89.445s          3.99%            89.445s          221.978us        402944           
native_batch_norm_backward                     4.09%            91.756s          4.09%            91.756s          3.734ms          24576            
mkldnn_convolution_backward                    8.64%            193.744s         8.64%            193.744s         9.460ms          20480            
InvertibleCheckpointFunctionBackward           1.22%            27.365s          50.70%           1136.456s        634.183ms        1792             
---------------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Self CPU time total: 2241.458s
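
For reference, a per-operator profile like the one above can be collected along these lines; this is a minimal sketch with a stand-in model and data, not the actual experiment code:

```python
import torch
import torch.nn as nn

# Stand-ins: the real experiment used a memcnn-wrapped custom CNN.
model = nn.Sequential(nn.Conv1d(64, 64, kernel_size=3, padding=1), nn.ReLU())
batch = torch.randn(256, 64, 128)  # 256 items per batch, 64 channels

with torch.autograd.profiler.profile() as prof:
    for _ in range(4):  # a few batches; the run above used 256
        model(batch).sum().backward()  # dummy scalar loss, just for profiling

# Aggregated per-operator statistics, like the table above
print(prof.key_averages().table(sort_by="self_cpu_time_total"))
```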

flip, native_batch_norm, relu, mul, mkldnn_convolution, bernoulli_, sum and mm certainly come from the custom convolution. You can see isnan in the mix as well, because torch.autograd.detect_anomaly is enabled. (Everything below 1% has been removed for brevity.)

Removing the backward operators as well leaves us with these two lines that are newly introduced by memcnn:

select                                         33.77%           757.009s         33.77%           757.009s         62.078us         12194560         
InvertibleCheckpointFunctionBackward           1.22%            27.365s          50.70%           1136.456s        634.183ms        1792             

Considering that, in theory, memcnn only adds two extra forward passes and the same single backward pass, I don't quite see how it ends up using 84.47% of the computation (33.77% for select plus 50.70% for InvertibleCheckpointFunctionBackward).

Hunting down where all the select operations happen should be a first step towards improving the performance of the library.
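
One way to do that hunting is to wrap suspected code regions in named profiler ranges; a minimal sketch with stand-in labels, model, and data:

```python
import torch
import torch.nn as nn
from torch.autograd.profiler import profile, record_function

model = nn.Conv1d(64, 64, kernel_size=3, padding=1)  # stand-in network

with profile() as prof:
    for _ in range(4):
        with record_function("dataset_generation"):   # labels are arbitrary
            batch = torch.randn(256, 64, 128)         # stand-in data pipeline
        with record_function("model_forward_backward"):
            model(batch).sum().backward()

# Each labeled range shows up as its own row, so operators such as
# `select` can be attributed to the region that issued them.
print(prof.key_averages().table(sort_by="cpu_time_total"))
```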

silvandeleemput commented 4 years ago

Hi @ClashLuke, just a few more questions about your experiment. Was it just a single-layer CNN? Did you use the memory saving option, or did you disable it? Could you provide the code that you used? Does it show similar behavior on the GPU (i.e. when running the profiler with use_cuda=True)?

What I do expect to happen with the memory saving option enabled is an increase in InvertibleCheckpointFunctionBackward run time, since it recomputes the input at every step in the backward chain. Without further information, though, this is not much more than a guess.
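
For reference, assuming memcnn's keep_input flag and the AdditiveCoupling usage from above, toggling the memory saving option looks roughly like this (Fm is a placeholder sub-module):

```python
import torch.nn as nn
import memcnn

fm = nn.Conv1d(64, 64, kernel_size=3, padding=1)  # placeholder Fm
coupling = memcnn.AdditiveCoupling(fm, split_dim=2)

# Memory saving enabled: inputs are freed after the forward pass and
# recomputed from the outputs during the backward pass.
saving = memcnn.InvertibleModuleWrapper(fn=coupling, keep_input=False)

# Memory saving disabled: inputs are retained, so no recomputation.
baseline = memcnn.InvertibleModuleWrapper(fn=coupling, keep_input=True)
```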

ClashLuke commented 4 years ago

1) The select calls come from my dataset generation. I did not take into account that the profiler would profile everything, hence the error. MemCNN therefore uses 50% of the total computation, or roughly 75% of the model computation (50.70% of the 66.23% that remains once the dataset's select calls are excluded).

2) It's an 8-layer 1D iRevNet with 64 inputs and 32 features.

3) Memory saving is enabled. I'm using InvertibleModuleWrapper(AdditiveCoupling(module, split_dim=2)).

4) I'll create a minimal example and run it on both CPU and GPU once my current training has finished.

I'd still go as far as saying that 75% is too much.
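
For concreteness, a sketch of the setup described in (2) and (3); the coupling sub-module is a stand-in, since the actual custom convolution isn't shown here, and reading "64 inputs and 32 features" as 64 channels with 32 hidden features is an assumption:

```python
import torch.nn as nn
import memcnn

def make_fm(channels=64, features=32):
    # Stand-in for the custom convolution block; Fm has to map its
    # split input back to the same shape.
    return nn.Sequential(
        nn.Conv1d(channels, features, kernel_size=3, padding=1),
        nn.BatchNorm1d(features),
        nn.ReLU(),
        nn.Conv1d(features, channels, kernel_size=3, padding=1),
    )

model = nn.Sequential(*[
    memcnn.InvertibleModuleWrapper(
        fn=memcnn.AdditiveCoupling(make_fm(), split_dim=2),
        keep_input=False,  # memory saving enabled, as in (3)
    )
    for _ in range(8)      # 8 layers, as in (2)
])
```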

silvandeleemput commented 4 years ago

@ClashLuke Hi, it would be interesting to get the wall-clock time for training your network for n epochs with and without memory saving enabled. That should give you an estimate of the overhead of the memory-saving option.
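
A minimal sketch of such a timing comparison, with the two model variants (keep_input=False vs. keep_input=True) built as in the sketches above; `saving` and `baseline` are hypothetical names:

```python
import time
import torch

def wall_clock(model, n_batches=8, epochs=1):
    # Crude wall-clock timing; batch generation is included in the
    # measurement, so use identical settings for both variants.
    start = time.perf_counter()
    for _ in range(epochs):
        for _ in range(n_batches):
            batch = torch.randn(256, 64, 128)  # stand-in data
            model(batch).sum().backward()
    return time.perf_counter() - start

print("with memory saving:   ", wall_clock(saving))
print("without memory saving:", wall_clock(baseline))
```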