tbung / pytorch-revnet

Implementation of the reversible residual network in pytorch
MIT License

RuntimeError: running_mean should contain 2 elements not 1 #10

Closed: zwep closed this issue 4 years ago

zwep commented 4 years ago

I can't tell where or why exactly this is happening, but it happens when I run the code from your repo. It's probably somewhere in the model definition.

tbung commented 4 years ago

Running the code in this repo without any changes in PyTorch 1.1.0 does not produce any error for me (although the code probably has some issues with the latest PyTorch, e.g. the LR scheduler's step() is called before the optimizer's step()).
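For reference, the order PyTorch has documented since 1.1.0 is `optimizer.step()` first and `scheduler.step()` afterwards; the reverse order triggers a warning and skips the first value of the learning-rate schedule. A minimal sketch of that loop, assuming a toy `nn.Linear` model, plain SGD and `StepLR` (all illustrative choices, not what pytorch-revnet actually uses):

```python
import torch
from torch import nn

# Toy setup, assumed for illustration only (not taken from pytorch-revnet)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Multiply the learning rate by 0.1 every 2 scheduler steps
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

x = torch.randn(8, 4)
y = torch.randn(8, 1)
lrs = []

for epoch in range(4):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()   # update the weights first ...
    scheduler.step()   # ... then advance the schedule (PyTorch >= 1.1.0 order)
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs)
```

With this order the learning rate stays at 0.1 for the first two epochs and then decays by a factor of 10 every two epochs.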

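However, this error usually comes from a size mismatch in the input: the number of features (channels) in the tensor fed to a BatchNorm layer does not match the number the layer was built for. running_mean is an internal buffer of the BatchNorm layers (a running statistic, not a learnable parameter) with one entry per input feature, i.e. per channel for image inputs. In your case the message suggests the input has 2 channels while the layer was built for 1.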
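To make this concrete: `nn.BatchNorm2d(num_features)` allocates `running_mean` (and `running_var`) with `num_features` entries and compares that count with dimension 1 of the input, i.e. the channel axis of an `(N, C, H, W)` tensor. A minimal sketch that reproduces the kind of message from this issue, assuming a layer built for 1 channel that receives 2 (the shapes are made up for illustration):

```python
import torch
from torch import nn

bn = nn.BatchNorm2d(num_features=1)   # running_mean has 1 entry
print(bn.running_mean.shape)          # torch.Size([1])

# running_mean is a registered buffer, not a learnable parameter
print("running_mean" in dict(bn.named_buffers()))      # True
print("running_mean" in dict(bn.named_parameters()))   # False

x = torch.randn(4, 2, 8, 8)           # (N, C, H, W) with C = 2 channels

error_message = None
try:
    bn(x)
except RuntimeError as err:
    # e.g. "running_mean should contain 2 elements not 1"
    error_message = str(err)
print(error_message)

# Fix: build the layer with as many features as the input has channels
bn_fixed = nn.BatchNorm2d(num_features=x.shape[1])
out = bn_fixed(x)
print(out.shape)                      # torch.Size([4, 2, 8, 8])
```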

For the ResNet, running_mean lives inside the BatchNorm2d layers; for the RevNet, you can see running_mean being defined here.

If you have any more questions, feel free to ask. If this solves your issue, please close it.

zwep commented 4 years ago

Thanks for your reply. The problem was on my side: using one of your predefined parameter settings worked fine.

Btw, while we are here: do you think it is possible to replace the 2D convolutional layers with transposed ones? I would think that this way you could build something like a U-Net/RevNet combination. That would of course change the objective: where you want to predict labels, my final goal is to perform an image transformation.

tbung commented 4 years ago

I am going to close this issue for housekeeping, but I am happy to continue our discussion here.

The beauty of this architecture is that it does not care what you do inside the blocks, as long as each block returns a tensor of the same shape as its input. If you are interested in working with invertible neural networks, you might want to check out FrEIA, a framework created by the research group I am currently working with. It is much more flexible and implements a reversible architecture that can learn non-volume-preserving mappings (unlike the RevNet architecture).
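To illustrate why the blocks don't matter: in a RevNet-style additive coupling the input is split into two halves along the channels, and `y1 = x1 + F(x2)`, `y2 = x2 + G(y1)`. This is inverted exactly by `x2 = y2 - G(y1)`, `x1 = y1 - F(x2)` whatever `F` and `G` are, provided they preserve shape. A stride-1 `ConvTranspose2d` therefore fits just as well as a `Conv2d`, while a strided (upsampling) one would not, because its output could no longer be added to the other half. The sketch below is a hypothetical illustration: the `AdditiveCoupling` class and the layer sizes are assumptions, not code from this repo or from FrEIA.

```python
import torch
from torch import nn


def shape_preserving_block(channels: int) -> nn.Module:
    # Stride-1 transposed conv with kernel 3 and padding 1 keeps H and W unchanged,
    # so its output can be added onto the other half of the input.
    return nn.Sequential(
        nn.ConvTranspose2d(channels, channels, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
    )


class AdditiveCoupling(nn.Module):
    """Hypothetical RevNet-style reversible block; F and G are arbitrary shape-preserving modules."""

    def __init__(self, channels: int):
        super().__init__()
        half = channels // 2  # assumes an even number of channels
        self.f = shape_preserving_block(half)
        self.g = shape_preserving_block(half)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = torch.chunk(x, 2, dim=1)  # split along the channel axis
        y1 = x1 + self.f(x2)
        y2 = x2 + self.g(y1)
        return torch.cat([y1, y2], dim=1)

    def inverse(self, y: torch.Tensor) -> torch.Tensor:
        y1, y2 = torch.chunk(y, 2, dim=1)
        x2 = y2 - self.g(y1)  # undo the second update first
        x1 = y1 - self.f(x2)
        return torch.cat([x1, x2], dim=1)


block = AdditiveCoupling(channels=4)
x = torch.randn(2, 4, 16, 16)
with torch.no_grad():
    y = block(x)
    x_rec = block.inverse(y)

print(y.shape)                              # torch.Size([2, 4, 16, 16])
print(torch.allclose(x, x_rec, atol=1e-5))  # True
```

Because each half is only shifted by a function of the other half, the Jacobian of this map is block-triangular with identity blocks on the diagonal, so its determinant is 1; that is what makes RevNet volume-preserving. Couplings that also scale one half, e.g. `y1 = x1 * exp(s(x2)) + t(x2)`, have log-determinant `sum(s(x2))` and can change volume. For a U-Net-like change of resolution, the reshaping would likewise have to be invertible rather than happen inside `F` or `G`.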