Hi @ClashLuke, just a few more questions about your experiment. Was it just a single-layer CNN? Did you use the memory saving option or did you disable it? Could you provide the code that you used? Does it show similar behavior on the GPU (i.e. running the profiler with use_cuda=True)?
What I do expect to happen with the memory saving option enabled is an increase in InvertibleCheckpointFunctionBackward run time (it will recompute the input at every step in the backward chain). Although without further information it is not much more than a guess.
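For context, a minimal sketch of that kind of check (the module, helper name `profile_once`, and shapes are placeholders, not the network from this issue): `keep_input=False` turns the memory saving on, `keep_input=True` turns it off, and `use_cuda=True` makes the profiler record GPU timings as well.

```python
import torch
import memcnn

def profile_once(keep_input, use_cuda=False):
    device = "cuda" if use_cuda else "cpu"
    # Single additive-coupling block with a placeholder coupling function.
    fm = torch.nn.Linear(16, 16)
    block = memcnn.InvertibleModuleWrapper(
        fn=memcnn.AdditiveCoupling(fm, split_dim=2),
        keep_input=keep_input, keep_input_inverse=keep_input,
    ).to(device)
    x = torch.randn(8, 4, 32, device=device, requires_grad=True)
    with torch.autograd.profiler.profile(use_cuda=use_cuda) as prof:
        block(x).mean().backward()
    # With keep_input=False, InvertibleCheckpointFunctionBackward should account for
    # the extra recomputation time in this table.
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=15))

profile_once(keep_input=False)   # memory saving enabled
profile_once(keep_input=True)    # memory saving disabled
if torch.cuda.is_available():
    profile_once(keep_input=False, use_cuda=True)
```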
1) The select comes from my dataset generation. I did not take into account that the profiler would profile everything, hence the error. MemCNN therefore uses 50% of the computation (or 75% of the model computation).
2) It's an 8-layer 1d-iRevNet with 64 inputs and 32 features.
3) Memory saving is enabled. I'm using InvertibleModuleWrapper(AdditiveCoupling(module, split_dim=2)) (a sketch of this setup follows below).
4) I'll create a minimal example and run it on both CPU and GPU once my current training has finished.
I'd still go as far as saying that 75% is too much.
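A minimal sketch of roughly that configuration, assuming the data is laid out as (batch, length, features) so that split_dim=2 splits the 32 features into two halves of 16; the coupling function, the `coupling_block` helper, and the shapes are guesses, not the actual training code:

```python
import torch
import memcnn

def coupling_block(features=32):
    # Placeholder coupling function; the real `module` from the quote above is unknown.
    fm = torch.nn.Sequential(
        torch.nn.Linear(features // 2, features // 2),
        torch.nn.ReLU(),
        torch.nn.Linear(features // 2, features // 2),
    )
    return memcnn.InvertibleModuleWrapper(
        fn=memcnn.AdditiveCoupling(fm, split_dim=2),
        keep_input=False,          # memory saving enabled, as stated in 3)
        keep_input_inverse=False,
    )

# 8 coupling layers, sequences of length 64 with 32 features each (assumed layout).
model = torch.nn.Sequential(*[coupling_block() for _ in range(8)])

x = torch.randn(256, 64, 32, requires_grad=True)
model(x).mean().backward()
```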
@ClashLuke Hi, it would be interesting to get the wall clock time for training your network for n epochs with and without the memory saving. This should give you an estimate of the overhead for using the memory saving.
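One way such a wall-clock comparison could look (placeholder model, synthetic data, and hypothetical helpers `build_model`/`train_seconds`; only the keep_input toggle is the memcnn-specific part):

```python
import time
import torch
import memcnn

def build_model(keep_input, layers=8, features=32):
    def block():
        fm = torch.nn.Linear(features // 2, features // 2)
        return memcnn.InvertibleModuleWrapper(
            fn=memcnn.AdditiveCoupling(fm, split_dim=2),
            keep_input=keep_input, keep_input_inverse=keep_input)
    return torch.nn.Sequential(*[block() for _ in range(layers)])

def train_seconds(model, steps=256, batch=256):
    opt = torch.optim.SGD(model.parameters(), lr=1e-3)
    start = time.perf_counter()
    for _ in range(steps):
        x = torch.randn(batch, 64, 32, requires_grad=True)
        loss = model(x).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
    return time.perf_counter() - start

print("memory saving on :", train_seconds(build_model(keep_input=False)))
print("memory saving off:", train_seconds(build_model(keep_input=True)))
```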
Profiling memcnn with torch.autograd.profiler.profile(), I got the following results after training a custom CNN (AdditiveCoupling, InvertibleModuleWrapper) for 256 batches @ 256 items per batch on CPU.
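A sketch of how a comparable profile can be collected; the model here is only a stand-in for the custom CNN, while the anomaly detection, batch count, and batch size mirror the description above:

```python
import torch
import memcnn

torch.autograd.set_detect_anomaly(True)   # as in the run above; explains the isnan entries

def block(features=32):
    # Stand-in coupling function; a 1d convolution produces the mkldnn_convolution entries on CPU.
    fm = torch.nn.Conv1d(features // 2, features // 2, kernel_size=3, padding=1)
    return memcnn.InvertibleModuleWrapper(
        fn=memcnn.AdditiveCoupling(fm, split_dim=1),
        keep_input=False, keep_input_inverse=False)

model = torch.nn.Sequential(*[block() for _ in range(8)])
opt = torch.optim.SGD(model.parameters(), lr=1e-3)

with torch.autograd.profiler.profile() as prof:
    for _ in range(256):                                    # 256 batches
        x = torch.randn(256, 32, 64, requires_grad=True)    # 256 items per batch
        loss = model(x).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()

print(prof.key_averages().table(sort_by="cpu_time_total"))
```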
`flip`, `native_batch_norm`, `relu`, `mul`, `mkldnn_convolution`, `bernoulli_`, `sum` and `mm` certainly come from the custom convolution. You can see `isnan` in the mix as well, due to `torch.autograd.detect_anomaly` being set to true. (Everything below 1% has been removed previously for brevity.)

Removing the backwards operators as well leaves us with these two lines that are newly introduced by memcnn. Considering that, in theory, you only have two additional forward passes and the same one backward pass, I don't quite see how memcnn uses 84.47% of the computation.

Hunting down where all the `select` operations happen should be a first step towards improving the performance of the library.
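One possible way to do that hunting is to profile with stack traces enabled, so each `select` row comes with a source location (a sketch with a tiny placeholder block and made-up shapes; with_stack and group_by_stack_n require a reasonably recent PyTorch):

```python
import torch
import memcnn

# Tiny stand-in block; the goal is only to attribute `select` calls to their call sites.
fm = torch.nn.Linear(16, 16)
block = memcnn.InvertibleModuleWrapper(
    fn=memcnn.AdditiveCoupling(fm, split_dim=2),
    keep_input=False, keep_input_inverse=False)
x = torch.randn(4, 8, 32, requires_grad=True)

with torch.autograd.profiler.profile(with_stack=True, record_shapes=True) as prof:
    block(x).mean().backward()

# The grouped table adds a source-location column; look for the select rows there.
print(prof.key_averages(group_by_stack_n=5).table(sort_by="self_cpu_time_total", row_limit=30))
```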